diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index 3ee4345456..1ef726efff 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -7,7 +7,7 @@ import v.token import v.table import v.errors -pub type TypeDecl = AliasTypeDecl | FnTypeDecl | SumTypeDecl +pub type TypeDecl = AliasTypeDecl | FnTypeDecl | SumTypeDecl | UnionSumTypeDecl pub type Expr = AnonFn | ArrayInit | AsCast | Assoc | AtExpr | BoolLiteral | CTempVar | CallExpr | CastExpr | ChanInit | CharLiteral | Comment | ComptimeCall | ConcatExpr | EnumVal | @@ -362,11 +362,23 @@ pub: is_arg bool // fn args should not be autofreed pub mut: typ table.Type + sum_type_cast table.Type pos token.Position is_used bool is_changed bool // to detect mutable vars that are never changed } +// used for smartcasting only +// struct fields change type in scopes +pub struct ScopeStructField { +pub: + struct_type table.Type // type of struct + name string + pos token.Position + typ table.Type + sum_type_cast table.Type +} + pub struct GlobalField { pub: name string @@ -522,7 +534,7 @@ pub: mut_name bool // `if mut name is` pub mut: stmts []Stmt - smartcast bool // true when cond is `x is SumType`, set in checker.if_expr + smartcast bool // true when cond is `x is SumType`, set in checker.if_expr // no longer needed with union sum types TODO: remove } pub struct UnsafeExpr { @@ -733,6 +745,16 @@ pub: pos token.Position } +// New implementation of sum types +pub struct UnionSumTypeDecl { +pub: + name string + is_pub bool + pos token.Position +pub mut: + sub_types []table.Type +} + pub struct FnTypeDecl { pub: name string @@ -1115,7 +1137,7 @@ pub fn (stmt Stmt) position() token.Position { AssertStmt, AssignStmt, Block, BranchStmt, CompFor, ConstDecl, DeferStmt, EnumDecl, ExprStmt, FnDecl, ForCStmt, ForInStmt, ForStmt, GotoLabel, GotoStmt, Import, Return, StructDecl, GlobalDecl, HashStmt, InterfaceDecl, Module, SqlStmt { return stmt.pos } GoStmt { return stmt.call_expr.position() } TypeDecl { match stmt { - AliasTypeDecl, FnTypeDecl, SumTypeDecl { return stmt.pos } + AliasTypeDecl, FnTypeDecl, SumTypeDecl, UnionSumTypeDecl { return stmt.pos } } } // Please, do NOT use else{} here. // This match is exhaustive *on purpose*, to help force diff --git a/vlib/v/ast/scope.v b/vlib/v/ast/scope.v index 52c1ab4b38..0d115c2ae7 100644 --- a/vlib/v/ast/scope.v +++ b/vlib/v/ast/scope.v @@ -8,11 +8,12 @@ import v.table pub struct Scope { pub mut: // mut: - objects map[string]ScopeObject - parent &Scope - children []&Scope - start_pos int - end_pos int + objects map[string]ScopeObject + struct_fields []ScopeStructField + parent &Scope + children []&Scope + start_pos int + end_pos int } pub fn new_scope(parent &Scope, start_pos int) &Scope { @@ -48,6 +49,20 @@ pub fn (s &Scope) find(name string) ?ScopeObject { return none } +pub fn (s &Scope) find_struct_field(struct_type table.Type, field_name string) ?ScopeStructField { + for sc := s; true; sc = sc.parent { + for field in sc.struct_fields { + if field.struct_type == struct_type && field.name == field_name { + return field + } + } + if isnil(sc.parent) { + break + } + } + return none +} + pub fn (s &Scope) is_known(name string) bool { if _ := s.find(name) { return true @@ -96,6 +111,15 @@ pub fn (mut s Scope) update_var_type(name string, typ table.Type) { } } +pub fn (mut s Scope) register_struct_field(field ScopeStructField) { + for f in s.struct_fields { + if f.struct_type == field.struct_type && f.name == field.name { + return + } + } + s.struct_fields << field +} + pub fn (mut s Scope) register(name string, obj ScopeObject) { if name == '_' { return @@ -163,6 +187,9 @@ pub fn (sc &Scope) show(depth int, max_depth int) string { else {} } } + for field in sc.struct_fields { + out += '$indent * struct_field: $field.struct_type $field.name - $field.typ\n' + } if max_depth == 0 || depth < max_depth - 1 { for i, _ in sc.children { out += sc.children[i].show(depth + 1, max_depth) diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index db5d9514cf..0ba5ab149c 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -28,40 +28,41 @@ const ( ) pub struct Checker { - pref &pref.Preferences // Preferences shared from V struct + pref &pref.Preferences // Preferences shared from V struct pub mut: - table &table.Table - file &ast.File = 0 - nr_errors int - nr_warnings int - errors []errors.Error - warnings []errors.Warning - error_lines []int // to avoid printing multiple errors for the same line - expected_type table.Type - cur_fn &ast.FnDecl // current function - const_decl string - const_deps []string - const_names []string - global_names []string - locked_names []string // vars that are currently locked - rlocked_names []string // vars that are currently read-locked - in_for_count int // if checker is currently in a for loop + table &table.Table + file &ast.File = 0 + nr_errors int + nr_warnings int + errors []errors.Error + warnings []errors.Warning + error_lines []int // to avoid printing multiple errors for the same line + expected_type table.Type + cur_fn &ast.FnDecl // current function + const_decl string + const_deps []string + const_names []string + global_names []string + locked_names []string // vars that are currently locked + rlocked_names []string // vars that are currently read-locked + in_for_count int // if checker is currently in a for loop // checked_ident string // to avoid infinite checker loops - returns bool - scope_returns bool - mod string // current module name - is_builtin_mod bool // are we in `builtin`? - inside_unsafe bool - skip_flags bool // should `#flag` and `#include` be skipped - cur_generic_type table.Type + returns bool + scope_returns bool + mod string // current module name + is_builtin_mod bool // are we in `builtin`? + inside_unsafe bool + skip_flags bool // should `#flag` and `#include` be skipped + cur_generic_type table.Type mut: - expr_level int // to avoid infinite recursion segfaults due to compiler bugs - inside_sql bool // to handle sql table fields pseudo variables - cur_orm_ts table.TypeSymbol - error_details []string - generic_funcs []&ast.FnDecl - vmod_file_content string // needed for @VMOD_FILE, contents of the file, *NOT its path* - vweb_gen_types []table.Type // vweb route checks + expr_level int // to avoid infinite recursion segfaults due to compiler bugs + inside_sql bool // to handle sql table fields pseudo variables + cur_orm_ts table.TypeSymbol + error_details []string + generic_funcs []&ast.FnDecl + vmod_file_content string // needed for @VMOD_FILE, contents of the file, *NOT its path** + vweb_gen_types []table.Type // vweb route checks + prevent_sum_type_unwrapping_once bool // needed for assign new values to sum type, stopping unwrapping then } pub fn new_checker(table &table.Table, pref &pref.Preferences) Checker { @@ -318,6 +319,17 @@ pub fn (mut c Checker) type_decl(node ast.TypeDecl) { } } } + ast.UnionSumTypeDecl { + c.check_valid_pascal_case(node.name, 'sum type', node.pos) + for typ in node.sub_types { + mut sym := c.table.get_type_symbol(typ) + if sym.kind == .placeholder { + c.error("type `$sym.source_name` doesn't exist", node.pos) + } else if sym.kind == .interface_ { + c.error('sum type cannot hold an interface', node.pos) + } + } + } } } @@ -786,9 +798,15 @@ pub fn (mut c Checker) infix_expr(mut infix_expr ast.InfixExpr) table.Type { c.error('$infix_expr.op.str(): type `$typ_sym.source_name` does not exist', type_expr.pos) } - if left.kind !in [.interface_, .sum_type] { + if left.kind !in [.interface_, .sum_type, .union_sum_type] { c.error('`$infix_expr.op.str()` can only be used with interfaces and sum types', infix_expr.pos) + } else if left.kind == .union_sum_type { + info := left.info as table.UnionSumType + if type_expr.typ !in info.variants { + c.error('`$left.source_name` has no variant `$right.source_name`', + infix_expr.pos) + } } return table.bool_type } @@ -1717,6 +1735,8 @@ fn is_expr_panic_or_exit(expr ast.Expr) bool { } pub fn (mut c Checker) selector_expr(mut selector_expr ast.SelectorExpr) table.Type { + prevent_sum_type_unwrapping_once := c.prevent_sum_type_unwrapping_once + c.prevent_sum_type_unwrapping_once = false // T.name, typeof(expr).name mut name_type := 0 match selector_expr.expr as left { @@ -1746,9 +1766,9 @@ pub fn (mut c Checker) selector_expr(mut selector_expr ast.SelectorExpr) table.T return table.void_type } selector_expr.expr_type = typ - sym := c.table.get_type_symbol(c.unwrap_generic(typ)) field_name := selector_expr.field_name - // variadic + utyp := c.unwrap_generic(typ) + sym := c.table.get_type_symbol(utyp) if typ.has_flag(.variadic) || sym.kind == .array_fixed || sym.kind == .chan { if field_name == 'len' || (sym.kind == .chan && field_name == 'cap') { selector_expr.typ = table.int_type @@ -1760,6 +1780,15 @@ pub fn (mut c Checker) selector_expr(mut selector_expr ast.SelectorExpr) table.T if sym.mod != c.mod && !field.is_pub && sym.language != .c { c.error('field `${sym.source_name}.$field_name` is not public', selector_expr.pos) } + field_sym := c.table.get_type_symbol(field.typ) + if field_sym.kind == .union_sum_type { + if !prevent_sum_type_unwrapping_once { + scope := c.file.scope.innermost(selector_expr.pos.pos) + if scope_field := scope.find_struct_field(utyp, field_name) { + return scope_field.sum_type_cast + } + } + } selector_expr.typ = field.typ return field.typ } else { @@ -1982,6 +2011,9 @@ pub fn (mut c Checker) assign_stmt(mut assign_stmt ast.AssignStmt) { is_blank_ident := left.is_blank_ident() mut left_type := table.void_type if !is_decl && !is_blank_ident { + if left is ast.Ident || left is ast.SelectorExpr { + c.prevent_sum_type_unwrapping_once = true + } left_type = c.expr(left) c.expected_type = c.unwrap_generic(left_type) } @@ -1992,7 +2024,7 @@ pub fn (mut c Checker) assign_stmt(mut assign_stmt ast.AssignStmt) { } } right := if i < assign_stmt.right.len { assign_stmt.right[i] } else { assign_stmt.right[0] } - right_type := assign_stmt.right_types[i] + mut right_type := assign_stmt.right_types[i] if is_decl { left_type = c.table.mktyp(right_type) if left_type == table.int_type { @@ -2159,15 +2191,58 @@ pub fn (mut c Checker) assign_stmt(mut assign_stmt ast.AssignStmt) { } else {} } - // Dual sides check (compatibility check) if !is_blank_ident && right_sym.kind != .placeholder { - c.check_expected(right_type_unwrapped, left_type_unwrapped) or { + // Assign to sum type if ordinary value + mut final_left_type := left_type_unwrapped + mut scope := c.file.scope.innermost(left.position().pos) + match left { + ast.SelectorExpr { + if _ := scope.find_struct_field(left.expr_type, left.field_name) { + final_left_type = right_type_unwrapped + mut inner_scope := c.open_scope(mut scope, left.pos.pos) + inner_scope.register_struct_field(ast.ScopeStructField{ + struct_type: left.expr_type + name: left.field_name + typ: final_left_type + sum_type_cast: right_type_unwrapped + pos: left.pos + }) + } + } + ast.Ident { + if v := scope.find_var(left.name) { + if v.sum_type_cast != 0 && + c.table.sumtype_has_variant(final_left_type, right_type_unwrapped) { + final_left_type = right_type_unwrapped + mut inner_scope := c.open_scope(mut scope, left.pos.pos) + inner_scope.register(left.name, ast.Var{ + name: left.name + typ: final_left_type + pos: left.pos + is_used: true + is_mut: left.is_mut + sum_type_cast: right_type_unwrapped + }) + } + } + } + else {} + } + // Dual sides check (compatibility check) + c.check_expected(right_type_unwrapped, final_left_type) or { c.error('cannot assign to `$left`: $err', right.position()) } } } } +fn (mut c Checker) open_scope(mut parent ast.Scope, start_pos int) &ast.Scope { + mut s := ast.new_scope(parent, start_pos) + s.end_pos = parent.end_pos + parent.children << s + return s +} + fn (mut c Checker) check_array_init_para_type(para string, expr ast.Expr, pos token.Position) { sym := c.table.get_type_symbol(c.expr(expr)) if sym.kind !in [.int, .any_int] { @@ -2704,7 +2779,7 @@ pub fn (mut c Checker) expr(node ast.Expr) table.Type { node.expr_type = c.expr(node.expr) expr_type_sym := c.table.get_type_symbol(node.expr_type) type_sym := c.table.get_type_symbol(node.typ) - if expr_type_sym.kind == .sum_type { + if expr_type_sym.kind == .sum_type || expr_type_sym.kind == .union_sum_type { if type_sym.kind == .placeholder { // Unknown type used in the right part of `as` c.error('unknown type `$type_sym.source_name`', node.pos) @@ -2716,13 +2791,15 @@ pub fn (mut c Checker) expr(node ast.Expr) table.Type { } } else { mut s := 'cannot cast non-sum type `$expr_type_sym.source_name` using `as`' - if type_sym.kind == .sum_type { + if type_sym.kind == .sum_type || expr_type_sym.kind == .union_sum_type { s += ' - use e.g. `${type_sym.source_name}(some_expr)` instead.' } c.error(s, node.pos) } + if expr_type_sym.kind == .union_sum_type { + return node.typ + } return node.typ.to_ptr() - // return node.typ } ast.Assoc { scope := c.file.scope.innermost(node.pos.pos) @@ -2968,7 +3045,7 @@ pub fn (mut c Checker) cast_expr(mut node ast.CastExpr) table.Type { c.error('can not cast type `byte` to string, use `${node.expr.str()}.str()` instead.', node.pos) } - if to_type_sym.kind == .sum_type { + if to_type_sym.kind == .sum_type || to_type_sym.kind == .union_sum_type { if node.expr_type in [table.any_int_type, table.any_flt_type] { node.expr_type = c.promote_num(node.expr_type, if node.expr_type == table.any_int_type { table.int_type } else { table.f64_type }) } @@ -3140,7 +3217,9 @@ pub fn (mut c Checker) ident(mut ident ast.Ident) table.Type { c.error('undefined variable `$ident.name` (used before declaration)', ident.pos) } - mut typ := obj.typ + is_sum_type_cast := obj.sum_type_cast != 0 && !c.prevent_sum_type_unwrapping_once + c.prevent_sum_type_unwrapping_once = false + mut typ := if is_sum_type_cast { obj.sum_type_cast } else { obj.typ } if typ == 0 { if obj.expr is ast.Ident { inner_ident := obj.expr as ast.Ident @@ -3164,7 +3243,9 @@ pub fn (mut c Checker) ident(mut ident ast.Ident) table.Type { // typ = c.cur_generic_type // } // } else { - obj.typ = typ + if !is_sum_type_cast { + obj.typ = typ + } ident.obj = obj1 // unwrap optional (`println(x)`) if is_optional { @@ -3270,7 +3351,7 @@ pub fn (mut c Checker) match_expr(mut node ast.MatchExpr) table.Type { c.error('compiler bug: match 0 cond type', node.pos) } cond_type_sym := c.table.get_type_symbol(cond_type) - if cond_type_sym.kind !in [.sum_type, .interface_] { + if cond_type_sym.kind !in [.sum_type, .interface_, .union_sum_type] { node.is_sum_type = false } c.match_exprs(mut node, cond_type_sym) @@ -3379,9 +3460,37 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol continue } match expr { - ast.Type { key = c.table.type_to_str(expr.typ) } - ast.EnumVal { key = expr.val } - else { key = expr.str() } + ast.Type { + key = c.table.type_to_str(expr.typ) + // smart cast only if one type is given (currently) // TODO make this work if types have same fields + if branch.exprs.len == 1 && cond_type_sym.kind == .union_sum_type { + mut scope := c.file.scope.innermost(branch.pos.pos) + match node.cond as node_cond { + ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{ + struct_type: node_cond.expr_type + name: node_cond.field_name + typ: node.cond_type + sum_type_cast: expr.typ + pos: node_cond.pos + }) } + ast.Ident { scope.register(node.var_name, ast.Var{ + name: node.var_name + typ: node.cond_type + pos: node_cond.pos + is_used: true + is_mut: node.is_mut + sum_type_cast: expr.typ + }) } + else {} + } + } + } + ast.EnumVal { + key = expr.val + } + else { + key = expr.str() + } } val := if key in branch_exprs { branch_exprs[key] } else { 0 } if val == 1 { @@ -3390,11 +3499,23 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol c.expected_type = node.cond_type expr_type := c.expr(expr) if cond_type_sym.kind == .interface_ { - c.type_implements(expr_type, c.expected_type, branch.pos) + // TODO + // This generates a memory issue with TCC + // Needs to be checked later when TCC errors are fixed + // Current solution is to move expr.position() to its own statement + // c.type_implements(expr_type, c.expected_type, expr.position()) + expr_pos := expr.position() + c.type_implements(expr_type, c.expected_type, expr_pos) + } else if cond_type_sym.info is table.UnionSumType as info { + if expr_type !in info.variants { + expr_str := c.table.type_to_str(expr_type) + expect_str := c.table.type_to_str(c.expected_type) + c.error('`$expect_str` has no variant `$expr_str`', expr.position()) + } } else if !c.check_types(expr_type, c.expected_type) { expr_str := c.table.type_to_str(expr_type) expect_str := c.table.type_to_str(c.expected_type) - c.error('cannot match `$expr_str` with `$expect_str` condition', branch.pos) + c.error('cannot match `$expr_str` with `$expect_str` condition', expr.position()) } branch_exprs[key] = val + 1 } @@ -3415,6 +3536,15 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol } } } + table.UnionSumType { + for v in info.variants { + v_str := c.table.type_to_str(v) + if v_str !in branch_exprs { + is_exhaustive = false + unhandled << '`$v_str`' + } + } + } // table.Enum { for v in info.vals { @@ -3609,28 +3739,48 @@ pub fn (mut c Checker) if_expr(mut node ast.IfExpr) table.Type { .variable } else { true } // Register shadow variable or `as` variable with actual type if is_variable { - if left_sym.kind in [.sum_type, .interface_] && branch.left_as_name.len > 0 { + if left_sym.kind in [.sum_type, .interface_, .union_sum_type] { mut is_mut := false mut scope := c.file.scope.innermost(branch.body_pos.pos) if infix.left is ast.Ident as infix_left { - if var := scope.find_var(infix_left.name) { - is_mut = var.is_mut + if v := scope.find_var(infix_left.name) { + is_mut = v.is_mut } - } else if infix.left is ast.SelectorExpr { - selector := infix.left as ast.SelectorExpr + if left_sym.kind == .union_sum_type { + scope.register(branch.left_as_name, ast.Var{ + name: branch.left_as_name + typ: infix.left_type + sum_type_cast: right_expr.typ + pos: infix.left.position() + is_used: true + is_mut: is_mut + }) + } + } else if infix.left is ast.SelectorExpr as selector { field := c.table.struct_find_field(left_sym, selector.field_name) or { table.Field{} } is_mut = field.is_mut + if left_sym.kind == .union_sum_type { + scope.register_struct_field(ast.ScopeStructField{ + struct_type: selector.expr_type + name: selector.field_name + typ: infix.left_type + sum_type_cast: right_expr.typ + pos: infix.left.position() + }) + } + } + if left_sym.kind != .union_sum_type && branch.left_as_name.len > 0 { + scope.register(branch.left_as_name, ast.Var{ + name: branch.left_as_name + typ: right_expr.typ.to_ptr() + pos: infix.left.position() + is_used: true + is_mut: is_mut + }) + node.branches[i].smartcast = true } - scope.register(branch.left_as_name, ast.Var{ - name: branch.left_as_name - typ: right_expr.typ.to_ptr() - pos: infix.left.position() - is_used: true - is_mut: is_mut - }) - node.branches[i].smartcast = true } } } diff --git a/vlib/v/checker/tests/match_expr_and_expected_type_error.out b/vlib/v/checker/tests/match_expr_and_expected_type_error.out index 21e27503bf..2d4089d19d 100644 --- a/vlib/v/checker/tests/match_expr_and_expected_type_error.out +++ b/vlib/v/checker/tests/match_expr_and_expected_type_error.out @@ -2,13 +2,13 @@ vlib/v/checker/tests/match_expr_and_expected_type_error.vv:3:3: error: cannot ma 1 | ch := `a` 2 | match ch { 3 | 'a' {} - | ~~~~~ + | ~~~ 4 | else {} 5 | } vlib/v/checker/tests/match_expr_and_expected_type_error.vv:9:3: error: cannot match `string` with `int` condition 7 | i := 123 8 | match i { 9 | 'a' {} - | ~~~~~ + | ~~~ 10 | else {} 11 | } diff --git a/vlib/v/checker/tests/match_invalid_type.out b/vlib/v/checker/tests/match_invalid_type.out index ab53b77029..8f01ff161b 100644 --- a/vlib/v/checker/tests/match_invalid_type.out +++ b/vlib/v/checker/tests/match_invalid_type.out @@ -2,7 +2,7 @@ vlib/v/checker/tests/match_invalid_type.vv:5:3: error: cannot match `byte` with 3 | fn sum() { 4 | match IoS(1) { 5 | byte { - | ~~~~~~ + | ~~~~ 6 | println('not cool') 7 | } vlib/v/checker/tests/match_invalid_type.vv:4:2: error: match must be exhaustive (add match branches for: `int`, `string` or `else {}` at the end) @@ -16,6 +16,6 @@ vlib/v/checker/tests/match_invalid_type.vv:24:3: error: `Cat` doesn't implement 22 | a := Animal(Dog{}) 23 | match a { 24 | Cat { - | ~~~~~ + | ~~~ 25 | println('not cool either') - 26 | } + 26 | } \ No newline at end of file diff --git a/vlib/v/checker/tests/match_undefined_cond.out b/vlib/v/checker/tests/match_undefined_cond.out index df80738ce9..8fbaa176c8 100644 --- a/vlib/v/checker/tests/match_undefined_cond.out +++ b/vlib/v/checker/tests/match_undefined_cond.out @@ -9,13 +9,13 @@ vlib/v/checker/tests/match_undefined_cond.vv:5:3: error: cannot match `any_int` 3 | fn main() { 4 | res := match Asd { 5 | 1 { 'foo' } - | ~~~ + | ^ 6 | 2 { 'test' } 7 | else { '' } vlib/v/checker/tests/match_undefined_cond.vv:6:3: error: cannot match `any_int` with `void` condition 4 | res := match Asd { 5 | 1 { 'foo' } 6 | 2 { 'test' } - | ~~~ + | ^ 7 | else { '' } 8 | } diff --git a/vlib/v/doc/utils.v b/vlib/v/doc/utils.v index 2ede76c4e2..311a4f3d7f 100644 --- a/vlib/v/doc/utils.v +++ b/vlib/v/doc/utils.v @@ -96,7 +96,7 @@ pub fn (d Doc) stmt_name(stmt ast.Stmt) string { match stmt { ast.FnDecl, ast.StructDecl, ast.EnumDecl, ast.InterfaceDecl { return stmt.name } ast.TypeDecl { match stmt { - ast.SumTypeDecl, ast.FnTypeDecl, ast.AliasTypeDecl { return stmt.name } + ast.SumTypeDecl, ast.FnTypeDecl, ast.AliasTypeDecl, ast.UnionSumTypeDecl { return stmt.name } } } ast.ConstDecl { return '' } // leave it blank else { return '' } diff --git a/vlib/v/fmt/fmt.v b/vlib/v/fmt/fmt.v index afdea27cf4..90789972ec 100644 --- a/vlib/v/fmt/fmt.v +++ b/vlib/v/fmt/fmt.v @@ -564,6 +564,25 @@ pub fn (mut f Fmt) type_decl(node ast.TypeDecl) { } // f.write(sum_type_names.join(' | ')) } + ast.UnionSumTypeDecl { + if node.is_pub { + f.write('pub ') + } + f.write('__type $node.name = ') + mut sum_type_names := []string{} + for t in node.sub_types { + sum_type_names << f.table.type_to_str(t) + } + sum_type_names.sort() + for i, name in sum_type_names { + f.write(name) + if i < sum_type_names.len - 1 { + f.write(' | ') + } + f.wrap_long_line(2, true) + } + // f.write(sum_type_names.join(' | ')) + } } f.writeln('\n') } diff --git a/vlib/v/gen/auto_str_methods.v b/vlib/v/gen/auto_str_methods.v index 948e33ceb8..347939e422 100644 --- a/vlib/v/gen/auto_str_methods.v +++ b/vlib/v/gen/auto_str_methods.v @@ -76,6 +76,9 @@ fn (mut g Gen) gen_str_for_type_with_styp(typ table.Type, styp string) string { table.SumType { g.gen_str_for_sum_type(it, styp, str_fn_name) } + table.UnionSumType { + g.gen_str_for_union_sum_type(it, styp, str_fn_name) + } else { verror("could not generate string method $str_fn_name for type \'$styp\'") } @@ -524,3 +527,47 @@ fn (mut g Gen) gen_str_for_sum_type(info table.SumType, styp string, str_fn_name g.auto_str_funcs.writeln('\t}') g.auto_str_funcs.writeln('}') } + +fn (mut g Gen) gen_str_for_union_sum_type(info table.UnionSumType, styp string, str_fn_name string) { + mut gen_fn_names := map[string]string{} + for typ in info.variants { + sym := g.table.get_type_symbol(typ) + if !sym.has_method('str') { + field_styp := g.typ(typ) + field_fn_name := g.gen_str_for_type_with_styp(typ, field_styp) + gen_fn_names[field_styp] = field_fn_name + } + } + // _str() functions should have a single argument, the indenting ones take 2: + g.type_definitions.writeln('string ${str_fn_name}($styp x); // auto') + g.auto_str_funcs.writeln('string ${str_fn_name}($styp x) { return indent_${str_fn_name}(x, 0); }') + g.type_definitions.writeln('string indent_${str_fn_name}($styp x, int indent_count); // auto') + g.auto_str_funcs.writeln('string indent_${str_fn_name}($styp x, int indent_count) {') + mut clean_sum_type_v_type_name := styp.replace('__', '.') + if styp.ends_with('*') { + clean_sum_type_v_type_name = '&' + clean_sum_type_v_type_name.replace('*', '') + } + clean_sum_type_v_type_name = util.strip_main_name(clean_sum_type_v_type_name) + g.auto_str_funcs.writeln('\tswitch(x.typ) {') + for typ in info.variants { + mut value_fmt := '%.*s\\000' + if typ == table.string_type { + value_fmt = "\'$value_fmt\'" + } + typ_str := g.typ(typ) + mut func_name := if typ_str in gen_fn_names { gen_fn_names[typ_str] } else { g.gen_str_for_type_with_styp(typ, + typ_str) } + sym := g.table.get_type_symbol(typ) + if sym.kind == .struct_ { + func_name = 'indent_$func_name' + } + g.auto_str_funcs.write('\t\tcase $typ: return _STR("${clean_sum_type_v_type_name}($value_fmt)", 2, ${func_name}(*($typ_str*)x._$typ.idx()') + if sym.kind == .struct_ { + g.auto_str_funcs.write(', indent_count') + } + g.auto_str_funcs.writeln('));') + } + g.auto_str_funcs.writeln('\t\tdefault: return tos_lit("unknown sum type value");') + g.auto_str_funcs.writeln('\t}') + g.auto_str_funcs.writeln('}') +} diff --git a/vlib/v/gen/cgen.v b/vlib/v/gen/cgen.v index ec4bec24dc..761fec83e7 100644 --- a/vlib/v/gen/cgen.v +++ b/vlib/v/gen/cgen.v @@ -24,103 +24,104 @@ const ( ) struct Gen { - pref &pref.Preferences - module_built string + pref &pref.Preferences + module_built string mut: - table &table.Table - out strings.Builder - cheaders strings.Builder - includes strings.Builder // all C #includes required by V modules - typedefs strings.Builder - typedefs2 strings.Builder - type_definitions strings.Builder // typedefs, defines etc (everything that goes to the top of the file) - definitions strings.Builder // typedefs, defines etc (everything that goes to the top of the file) - inits map[string]strings.Builder // contents of `void _vinit(){}` - cleanups map[string]strings.Builder // contents of `void _vcleanup(){}` - gowrappers strings.Builder // all go callsite wrappers - stringliterals strings.Builder // all string literals (they depend on tos3() beeing defined - auto_str_funcs strings.Builder // function bodies of all auto generated _str funcs - comptime_defines strings.Builder // custom defines, given by -d/-define flags on the CLI - pcs_declarations strings.Builder // -prof profile counter declarations for each function - hotcode_definitions strings.Builder // -live declarations & functions - shared_types strings.Builder // shared/lock types - channel_definitions strings.Builder // channel related code - options_typedefs strings.Builder // Option typedefs - options strings.Builder // `Option_xxxx` types - json_forward_decls strings.Builder // json type forward decls - enum_typedefs strings.Builder // enum types - sql_buf strings.Builder // for writing exprs to args via `sqlite3_bind_int()` etc - file ast.File - fn_decl &ast.FnDecl // pointer to the FnDecl we are currently inside otherwise 0 - last_fn_c_name string - tmp_count int // counter for unique tmp vars (_tmp1, tmp2 etc) - tmp_count2 int // a separate tmp var counter for autofree fn calls - variadic_args map[string]int - is_c_call bool // e.g. `C.printf("v")` - is_assign_lhs bool // inside left part of assign expr (for array_set(), etc) - is_assign_rhs bool // inside right part of assign after `=` (val expr) - is_array_set bool - is_amp bool // for `&Foo{}` to merge PrefixExpr `&` and StructInit `Foo{}`; also for `&byte(0)` etc - is_sql bool // Inside `sql db{}` statement, generating sql instead of C (e.g. `and` instead of `&&` etc) - is_shared bool // for initialization of hidden mutex in `[rw]shared` literals - is_vlines_enabled bool // is it safe to generate #line directives when -g is passed - vlines_path string // set to the proper path for generating #line directives - optionals []string // to avoid duplicates TODO perf, use map - chan_pop_optionals []string // types for `x := <-ch or {...}` - shareds []int // types with hidden mutex for which decl has been emitted - inside_ternary int // ?: comma separated statements on a single line - inside_map_postfix bool // inside map++/-- postfix expr - inside_map_infix bool // inside map< fn counter name - is_builtin_mod bool - hotcode_fn_names []string + ternary_names map[string]string + ternary_level_names map[string][]string + stmt_path_pos []int // positions of each statement start, for inserting C statements before the current statement + skip_stmt_pos bool // for handling if expressions + autofree (since both prepend C statements) + right_is_opt bool + autofree bool + indent int + empty_line bool + is_test bool + assign_op token.Kind // *=, =, etc (for array_set) + defer_stmts []ast.DeferStmt + defer_ifdef string + defer_profile_code string + str_types []string // types that need automatic str() generation + threaded_fns []string // for generating unique wrapper types and fns for `go xxx()` + array_fn_definitions []string // array equality functions that have been defined + map_fn_definitions []string // map equality functions that have been defined + is_json_fn bool // inside json.encode() + json_types []string // to avoid json gen duplicates + pcs []ProfileCounterMeta // -prof profile counter fn_names => fn counter name + is_builtin_mod bool + hotcode_fn_names []string // cur_fn ast.FnDecl - cur_generic_type table.Type // `int`, `string`, etc in `foo()` - sql_i int - sql_stmt_name string - sql_side SqlExprSide // left or right, to distinguish idents in `name == name` - inside_vweb_tmpl bool - inside_return bool - inside_or_block bool - strs_to_free0 []string // strings.Builder + cur_generic_type table.Type // `int`, `string`, etc in `foo()` + sql_i int + sql_stmt_name string + sql_side SqlExprSide // left or right, to distinguish idents in `name == name` + inside_vweb_tmpl bool + inside_return bool + inside_or_block bool + strs_to_free0 []string // strings.Builder // strs_to_free []string // strings.Builder - inside_call bool - has_main bool - inside_const bool - comp_for_method string // $for method in T { - comptime_var_type_map map[string]table.Type - match_sumtype_exprs []ast.Expr - match_sumtype_syms []table.TypeSymbol + inside_call bool + has_main bool + inside_const bool + comp_for_method string // $for method in T { + comptime_var_type_map map[string]table.Type + match_sumtype_exprs []ast.Expr + match_sumtype_syms []table.TypeSymbol // tmp_arg_vars_to_free []string // autofree_pregen map[string]string // autofree_pregen_buf strings.Builder // autofree_tmp_vars []string // to avoid redefining the same tmp vars in a single function - called_fn_name string - cur_mod string - is_js_call bool // for handling a special type arg #1 `json.decode(User, ...)` + called_fn_name string + cur_mod string + is_js_call bool // for handling a special type arg #1 `json.decode(User, ...)` // nr_vars_to_free int // doing_autofree_tmp bool - inside_lambda bool + inside_lambda bool + prevent_sum_type_unwrapping_once bool // needed for assign new values to sum type } const ( @@ -1209,8 +1210,85 @@ fn (mut g Gen) for_in(it ast.ForInStmt) { } } +// use instead of expr() when you need to cast to union sum type (can add other casts also) +fn (mut g Gen) union_expr_with_cast(expr ast.Expr, got_type table.Type, expected_type table.Type) { + // cast to sum type + if expected_type != table.void_type { + expected_is_ptr := expected_type.is_ptr() + expected_deref_type := if expected_is_ptr { expected_type.deref() } else { expected_type } + got_is_ptr := got_type.is_ptr() + got_deref_type := if got_is_ptr { got_type.deref() } else { got_type } + if g.table.sumtype_has_variant(expected_deref_type, got_deref_type) { + exp_styp := g.typ(expected_type) + got_styp := g.typ(got_type) + got_idx := got_type.idx() + got_sym := g.table.get_type_symbol(got_type) + if expected_is_ptr && got_is_ptr { + exp_der_styp := g.typ(expected_deref_type) + g.write('/* union sum type cast 1 */ ($exp_styp) memdup(&($exp_der_styp){._$got_type = ') + g.expr(expr) + g.write(', .typ = $got_type /* $got_sym.name */}, sizeof($exp_der_styp))') + } else if expected_is_ptr { + exp_der_styp := g.typ(expected_deref_type) + g.write('/* union sum type cast 2 */ ($exp_styp) memdup(&($exp_der_styp){._$got_type = memdup(&($got_styp[]){') + g.expr(expr) + g.write('}, sizeof($got_styp)), .typ = $got_type /* $got_sym.name */}, sizeof($exp_der_styp))') + } else if got_is_ptr { + g.write('/* union sum type cast 3 */ ($exp_styp){._$got_idx = ') + g.expr(expr) + g.write(', .typ = $got_type /* $got_sym.name */}') + } else { + mut is_already_sum_type := false + scope := g.file.scope.innermost(expr.position().pos) + if expr is ast.Ident { + if v := scope.find_var(expr.name) { + if v.sum_type_cast != 0 { + is_already_sum_type = true + } + } + } else if expr is ast.SelectorExpr { + if _ := scope.find_struct_field(expr.expr_type, expr.field_name) { + is_already_sum_type = true + } + } + if is_already_sum_type { + // Don't create a new sum type wrapper if there is already one + g.prevent_sum_type_unwrapping_once = true + g.expr(expr) + } else { + g.write('/* union sum type cast 4 */ ($exp_styp){._$got_type = memdup(&($got_styp[]){') + g.expr(expr) + g.write('}, sizeof($got_styp)), .typ = $got_type /* $got_sym.name */}') + } + } + return + } + } + // Generic dereferencing logic + expected_sym := g.table.get_type_symbol(expected_type) + got_is_ptr := got_type.is_ptr() + expected_is_ptr := expected_type.is_ptr() + neither_void := table.voidptr_type !in [got_type, expected_type] + if got_is_ptr && !expected_is_ptr && neither_void && expected_sym.kind !in [.interface_, .placeholder] { + got_deref_type := got_type.deref() + deref_sym := g.table.get_type_symbol(got_deref_type) + deref_will_match := expected_type in [got_type, got_deref_type, deref_sym.parent_idx] + got_is_opt := got_type.has_flag(.optional) + if deref_will_match || got_is_opt { + g.write('*') + } + } + // no cast + g.expr(expr) +} + // use instead of expr() when you need to cast to sum type (can add other casts also) fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type table.Type, expected_type table.Type) { + sym := g.table.get_type_symbol(expected_type) + if sym.kind == .union_sum_type { + g.union_expr_with_cast(expr, got_type, expected_type) + return + } // cast to sum type if expected_type != table.void_type { expected_is_ptr := expected_type.is_ptr() @@ -1725,6 +1803,9 @@ fn (mut g Gen) gen_assign_stmt(assign_stmt ast.AssignStmt) { } g.write('$styp ') } + if left is ast.Ident || left is ast.SelectorExpr { + g.prevent_sum_type_unwrapping_once = true + } g.expr(left) } if is_inside_ternary && is_decl { @@ -2132,7 +2213,7 @@ fn (mut g Gen) expr(node ast.Expr) { g.expr(node.arg) } g.write(')') - } else if sym.kind == .sum_type { + } else if sym.kind in [.sum_type, .union_sum_type] { g.expr_with_cast(node.expr, node.expr_type, node.typ) } else if sym.kind == .struct_ && !node.typ.is_ptr() && !(sym.info as table.Struct).is_typedef { styp := g.typ(node.typ) @@ -2374,6 +2455,8 @@ fn (mut g Gen) expr(node ast.Expr) { g.struct_init(node) } ast.SelectorExpr { + prevent_sum_type_unwrapping_once := g.prevent_sum_type_unwrapping_once + g.prevent_sum_type_unwrapping_once = false if node.name_type > 0 { g.type_name(node.name_type) return @@ -2394,6 +2477,21 @@ fn (mut g Gen) expr(node ast.Expr) { g.write(')') return } + mut sum_type_deref_field := '' + if field := g.table.struct_find_field(sym, node.field_name) { + field_sym := g.table.get_type_symbol(field.typ) + if field_sym.kind == .union_sum_type { + if !prevent_sum_type_unwrapping_once { + // check first if field is sum type because scope searching is expensive + scope := g.file.scope.innermost(node.pos.pos) + if field := scope.find_struct_field(node.expr_type, node.field_name) { + // union sum type deref + g.write('(*') + sum_type_deref_field = '_$field.sum_type_cast' + } + } + } + } g.expr(node.expr) // struct embedding if sym.kind == .struct_ { @@ -2419,6 +2517,9 @@ fn (mut g Gen) expr(node ast.Expr) { verror('cgen: SelectorExpr | expr_type: 0 | it.expr: `$node.expr` | field: `$node.field_name` | file: $g.file.path | line: $node.pos.line_nr') } g.write(c_name(node.field_name)) + if sum_type_deref_field != '' { + g.write('.$sum_type_deref_field)') + } } ast.Type { // match sum Type @@ -2870,6 +2971,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str // iterates through all types in sumtype branches for { is_last := j == node.branches.len - 1 + sym := g.table.get_type_symbol(node.cond_type) if branch.is_else || (node.is_expr && is_last) { if is_expr { // TODO too many branches. maybe separate ?: matches @@ -2891,9 +2993,8 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str g.write('if (') } g.write(cond_var) - sym := g.table.get_type_symbol(node.cond_type) // branch_sym := g.table.get_type_symbol(branch.typ) - if sym.kind == .sum_type { + if sym.kind in [.sum_type, .union_sum_type] { dot_or_ptr := if node.cond_type.is_ptr() { '->' } else { '.' } g.write(dot_or_ptr) g.write('typ == ') @@ -2909,7 +3010,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str } } // g.writeln('/* M sum_type=$node.is_sum_type is_expr=$node.is_expr exp_type=${g.typ(node.expected_type)}*/') - if !branch.is_else && !node.is_expr { + if sym.kind != .union_sum_type && !branch.is_else && !node.is_expr { // Use the nodes in the expr to generate `it` variable. type_expr := branch.exprs[sumtype_index] if type_expr !is ast.Type { @@ -3168,6 +3269,8 @@ fn (mut g Gen) select_expr(node ast.SelectExpr) { } fn (mut g Gen) ident(node ast.Ident) { + prevent_sum_type_unwrapping_once := g.prevent_sum_type_unwrapping_once + g.prevent_sum_type_unwrapping_once = false if node.name == 'lld' { return } @@ -3196,6 +3299,15 @@ fn (mut g Gen) ident(node ast.Ident) { g.write('${name}.val') return } + scope := g.file.scope.innermost(node.pos.pos) + if v := scope.find_var(node.name) { + if v.sum_type_cast != 0 { + if !prevent_sum_type_unwrapping_once { + g.write('(*${name}._$v.sum_type_cast)') + return + } + } + } } g.write(g.get_ternary_name(name)) } @@ -3214,11 +3326,22 @@ fn (mut g Gen) should_write_asterisk_due_to_match_sumtype(expr ast.Expr) bool { fn (mut g Gen) match_sumtype_has_no_struct_and_contains(node ast.Ident) bool { for i, expr in g.match_sumtype_exprs { if expr is ast.Ident && node.name == (expr as ast.Ident).name { - sumtype := g.match_sumtype_syms[i].info as table.SumType - for typ in sumtype.variants { - if g.table.get_type_symbol(typ).kind == .struct_ { - return false + match g.match_sumtype_syms[i].info as sumtype { + table.SumType { + for typ in sumtype.variants { + if g.table.get_type_symbol(typ).kind == .struct_ { + return false + } + } } + table.UnionSumType { + for typ in sumtype.variants { + if g.table.get_type_symbol(typ).kind == .struct_ { + return false + } + } + } + else {} } return true } @@ -4385,6 +4508,22 @@ fn (mut g Gen) write_types(types []table.TypeSymbol) { g.type_definitions.writeln('} $name;') g.type_definitions.writeln('') } + table.UnionSumType { + g.type_definitions.writeln('') + g.type_definitions.writeln('// Union sum type $name = ') + for variant in it.variants { + g.type_definitions.writeln('// | ${variant:4d} = ${g.typ(variant.idx()):-20s}') + } + g.type_definitions.writeln('typedef struct {') + g.type_definitions.writeln(' union {') + for variant in g.table.get_union_sum_type_variants(it) { + g.type_definitions.writeln(' ${g.typ(variant.to_ptr())} _$variant.idx();') + } + g.type_definitions.writeln(' };') + g.type_definitions.writeln(' int typ;') + g.type_definitions.writeln('} $name;') + g.type_definitions.writeln('') + } table.ArrayFixed { // .array_fixed { styp := util.no_dots(typ.name) @@ -5327,6 +5466,17 @@ fn (mut g Gen) as_cast(node ast.AsCast) { g.write(')') g.write(dot) g.write('typ, /*expected:*/$node.typ)') + } else if expr_type_sym.kind == .union_sum_type { + dot := if node.expr_type.is_ptr() { '->' } else { '.' } + g.write('/* as */ *($styp*)__as_cast((') + g.expr(node.expr) + g.write(')') + g.write(dot) + g.write('_$node.typ.idx(), (') + g.expr(node.expr) + g.write(')') + g.write(dot) + g.write('typ, /*expected:*/$node.typ)') } } @@ -5348,7 +5498,7 @@ fn (mut g Gen) is_expr(node ast.InfixExpr) { sub_sym := g.table.get_type_symbol(sub_type.typ) g.write('_${c_name(sym.name)}_${c_name(sub_sym.name)}_index') return - } else if sym.kind == .sum_type { + } else if sym.kind in [.sum_type, .union_sum_type] { g.write('typ $eq ') } g.expr(node.right) diff --git a/vlib/v/gen/js/js.v b/vlib/v/gen/js/js.v index 57d6df7f7f..1d6f765cd3 100644 --- a/vlib/v/gen/js/js.v +++ b/vlib/v/gen/js/js.v @@ -261,6 +261,10 @@ pub fn (mut g JsGen) typ(t table.Type) string { // TODO: Implement sumtypes styp = 'sym_type' } + .union_sum_type { + // TODO: Implement sumtypes + styp = 'union_sym_type' + } .alias { // TODO: Implement aliases styp = 'alias' diff --git a/vlib/v/parser/if_match.v b/vlib/v/parser/if_match.v index 6012451edf..95e242aae7 100644 --- a/vlib/v/parser/if_match.v +++ b/vlib/v/parser/if_match.v @@ -178,6 +178,11 @@ fn (mut p Parser) match_expr() ast.MatchExpr { match_first_pos := p.tok.position() p.inside_match = true p.check(.key_match) + mut is_union_match := false + if p.tok.kind == .key_union { + p.check(.key_union) + is_union_match = true + } is_mut := p.tok.kind == .key_mut mut is_sum_type := false if is_mut { @@ -230,7 +235,7 @@ fn (mut p Parser) match_expr() ast.MatchExpr { types << parsed_type exprs << ast.Type{ typ: parsed_type - pos: p.tok.position() + pos: p.prev_tok.position() } if p.tok.kind != .comma { break @@ -263,23 +268,25 @@ fn (mut p Parser) match_expr() ast.MatchExpr { } }) } - p.scope.register('it', ast.Var{ - name: 'it' - typ: it_typ.to_ptr() - pos: cond_pos - is_used: true - is_mut: is_mut - }) - if var_name.len > 0 { - // Register shadow variable or `as` variable with actual type - p.scope.register(var_name, ast.Var{ - name: var_name + if !is_union_match { + p.scope.register('it', ast.Var{ + name: 'it' typ: it_typ.to_ptr() pos: cond_pos is_used: true - is_changed: true // TODO mut unchanged warning hack, remove is_mut: is_mut }) + if var_name.len > 0 { + // Register shadow variable or `as` variable with actual type + p.scope.register(var_name, ast.Var{ + name: var_name + typ: it_typ.to_ptr() + pos: cond_pos + is_used: true + is_changed: true // TODO mut unchanged warning hack, remove + is_mut: is_mut + }) + } } is_sum_type = true } else { diff --git a/vlib/v/parser/parser.v b/vlib/v/parser/parser.v index 608dbc96f5..5a9105182c 100644 --- a/vlib/v/parser/parser.v +++ b/vlib/v/parser/parser.v @@ -434,6 +434,9 @@ pub fn (mut p Parser) top_stmt() ast.Stmt { .key_type { return p.type_decl() } + .key___type { + return p.union_sum_type_decl() + } else { p.error('wrong pub keyword usage') return ast.Stmt{} @@ -476,6 +479,9 @@ pub fn (mut p Parser) top_stmt() ast.Stmt { .key_type { return p.type_decl() } + .key___type { + return p.union_sum_type_decl() + } .key_enum { return p.enum_decl() } @@ -1855,6 +1861,58 @@ $pubfn (mut e $enum_name) toggle(flag $enum_name) { unsafe{ *e = int(*e) ^ ( } } +fn (mut p Parser) union_sum_type_decl() ast.TypeDecl { + start_pos := p.tok.position() + is_pub := p.tok.kind == .key_pub + if is_pub { + p.next() + } + p.check(.key___type) + end_pos := p.tok.position() + decl_pos := start_pos.extend(end_pos) + name := p.check_name() + if name.len == 1 && name[0].is_capital() { + p.error_with_pos('single letter capital names are reserved for generic template types.', + decl_pos) + } + p.check(.assign) + mut sum_variants := []table.Type{} + first_type := p.parse_type() // need to parse the first type before we can check if it's `type A = X | Y` + if p.tok.kind == .pipe { + p.next() + sum_variants << first_type + // type SumType = A | B | c + for { + variant_type := p.parse_type() + sum_variants << variant_type + if p.tok.kind != .pipe { + break + } + p.check(.pipe) + } + prepend_mod_name := p.prepend_mod(name) + p.table.register_type_symbol(table.TypeSymbol{ + kind: .union_sum_type + name: prepend_mod_name + source_name: prepend_mod_name + mod: p.mod + info: table.UnionSumType{ + variants: sum_variants + } + is_public: is_pub + }) + return ast.UnionSumTypeDecl{ + name: name + is_pub: is_pub + sub_types: sum_variants + pos: decl_pos + } + } + // just for this implementation + p.error_with_pos('wrong union sum type declaration', decl_pos) + return ast.TypeDecl{} +} + fn (mut p Parser) type_decl() ast.TypeDecl { start_pos := p.tok.position() is_pub := p.tok.kind == .key_pub diff --git a/vlib/v/table/atypes.v b/vlib/v/table/atypes.v index 5b2703d142..f3a8417578 100644 --- a/vlib/v/table/atypes.v +++ b/vlib/v/table/atypes.v @@ -16,7 +16,7 @@ import strings pub type Type = int pub type TypeInfo = Aggregate | Alias | Array | ArrayFixed | Chan | Enum | FnType | GenericStructInst | - Interface | Map | MultiReturn | Struct | SumType + Interface | Map | MultiReturn | Struct | SumType | UnionSumType pub enum Language { v @@ -362,6 +362,7 @@ pub enum Kind { generic_struct_inst multi_return sum_type + union_sum_type alias enum_ function @@ -676,6 +677,7 @@ pub fn (k Kind) str() string { .chan { 'chan' } .multi_return { 'multi_return' } .sum_type { 'sum_type' } + .union_sum_type { 'union_sum_type' } .alias { 'alias' } .enum_ { 'enum' } .any { 'any' } @@ -803,6 +805,23 @@ pub: variants []Type } +pub struct UnionSumType { +pub: + variants []Type +} + +pub fn (table &Table) get_union_sum_type_variants(sum_type UnionSumType) []Type { + mut variants := []Type{} + for variant in sum_type.variants { + sym := table.get_type_symbol(variant) + if sym.info is UnionSumType as sym_info { + variants << table.get_union_sum_type_variants(sym_info) + } + variants << variant + } + return variants +} + pub fn (table &Table) type_to_str(t Type) string { sym := table.get_type_symbol(t) mut res := sym.source_name diff --git a/vlib/v/table/table.v b/vlib/v/table/table.v index 19e0f24d7e..5461e6a233 100644 --- a/vlib/v/table/table.v +++ b/vlib/v/table/table.v @@ -732,6 +732,16 @@ pub fn (table &Table) sumtype_has_variant(parent Type, variant Type) bool { return true } } + } else if parent_sym.kind == .union_sum_type { + parent_info := parent_sym.info as UnionSumType + for v in parent_info.variants { + if v.idx() == variant.idx() { + return true + } + if table.sumtype_has_variant(v, variant) { + return true + } + } } return false } diff --git a/vlib/v/tests/union_sum_type_test.v b/vlib/v/tests/union_sum_type_test.v new file mode 100644 index 0000000000..b9fbf16ed0 --- /dev/null +++ b/vlib/v/tests/union_sum_type_test.v @@ -0,0 +1,453 @@ +__type Expr = IfExpr | IntegerLiteral +__type Stmt = FnDecl | StructDecl +__type Node = Expr | Stmt + +struct FnDecl { + pos int +} + +struct StructDecl { + pos int +} + + +struct IfExpr { + pos int +} + +struct IntegerLiteral { + val string +} + +fn handle(e Expr) string { + is_literal := e is IntegerLiteral + assert is_literal + assert !(e !is IntegerLiteral) + if e is IntegerLiteral { + println('int') + } + match union e { + IntegerLiteral { + assert e.val == '12' + // assert e.val == '12' // TODO + return 'int' + } + IfExpr { + return 'if' + } + } + return '' +} + +fn test_expr() { + expr := IntegerLiteral{ + val: '12' + } + assert handle(expr) == 'int' + // assert expr is IntegerLiteral // TODO +} + +fn test_assignment_and_push() { + mut expr1 := Expr{} + mut arr1 := []Expr{} + expr := IntegerLiteral{ + val: '111' + } + arr1 << expr + match union arr1[0] { + IntegerLiteral { + arr1 << arr1[0] + // should ref/dereference on assignent be made automatic? + // currently it is done for return stmt and fn args + expr1 = arr1[0] + } + else {} + } +} + +// Test moving structs between master/sub arrays +__type Master = Sub1 | Sub2 + +struct Sub1 { +mut: + val int + name string +} + +struct Sub2 { + name string + val int +} + +fn test_converting_down() { + mut out := []Master{} + out << Sub1{ + val: 1 + name: 'one' + } + out << Sub2{ + val: 2 + name: 'two' + } + out << Sub2{ + val: 3 + name: 'three' + } + mut res := []Sub2{cap: out.len} + for d in out { + match union d { + Sub2 { res << d } + else {} + } + } + assert res[0].val == 2 + assert res[0].name == 'two' + assert res[1].val == 3 + assert res[1].name == 'three' +} + +fn test_nested_sumtype() { + mut a := Node{} + mut b := Node{} + a = StructDecl{pos: 1} + b = IfExpr{pos: 1} + c := Node(Expr(IfExpr{pos:1})) + if c is Expr { + if c is IfExpr { + assert true + } + else { + assert false + } + } + else { + assert false + } +} + +__type Abc = int | string + +fn test_string_cast_to_sumtype() { + a := Abc('test') + match union a { + int { + assert false + } + string { + assert true + } + } +} + +fn test_int_cast_to_sumtype() { + // literal + a := Abc(111) + match union a { + int { + assert true + } + string { + assert false + } + } + // var + i := 111 + b := Abc(i) + match union b { + int { + assert true + } + string { + assert false + } + } +} + +// TODO: change definition once types other than int and f64 (int, f64, etc) are supported in sumtype +__type Number = int | f64 + +fn is_gt_simple(val string, dst Number) bool { + match union dst { + int { + return val.int() > dst + } + f64 { + return dst < val.f64() + } + } +} + +fn is_gt_nested(val string, dst Number) bool { + dst2 := dst + match union dst { + int { + match union dst2 { + int { + return val.int() > dst + } + // this branch should never been hit + else { + return val.int() < dst + } + } + } + f64 { + match union dst2 { + f64 { + return dst < val.f64() + } + // this branch should never been hit + else { + return dst > val.f64() + } + } + } + } +} + +fn concat(val string, dst Number) string { + match union dst { + int { + mut res := val + '(int)' + res += dst.str() + return res + } + f64 { + mut res := val + '(float)' + res += dst.str() + return res + } + } +} + +fn get_sum(val string, dst Number) f64 { + match union dst { + int { + mut res := val.int() + res += dst + return res + } + f64 { + mut res := val.f64() + res += dst + return res + } + } +} + +__type Bar = string | Test +__type Xyz = int | string + +struct Test { + x string + xyz Xyz +} + +struct Foo { + y Bar +} + +fn test_nested_selector_smartcast() { + f := Foo{ + y: Bar(Test{ + x: 'Hi' + xyz: Xyz(5) + }) + } + + if f.y is Test { + z := f.y.x + assert f.y.x == 'Hi' + assert z == 'Hi' + if f.y.xyz is int { + assert f.y.xyz == 5 + } + } +} + +fn test_as_cast() { + f := Foo{ + y: Bar('test') + } + y := f.y as string + assert y == 'test' +} + +fn test_assignment() { + y := 5 + mut x := Xyz(y) + x = 'test' + + if x is string { + assert x == 'test' + } +} + +__type Inner = int | string +struct InnerStruct { +mut: + x Inner +} +__type Outer = string | InnerStruct + +fn test_nested_if_is() { + mut b := Outer(InnerStruct{Inner(0)}) + if b is InnerStruct { + if b.x is int { + println(b.x) + } + } +} + +fn test_casted_sum_type_selector_reassign() { + mut b := InnerStruct{Inner(0)} + if b.x is int { + assert typeof(b.x) == 'int' + b.x = 'test' + assert typeof(b.x) == 'string' + } + assert typeof(b.x) == 'Inner' +} + +fn test_casted_sum_type_ident_reassign() { + mut x := Inner(0) + if x is int { + assert typeof(x) == 'int' + x = 'test' + assert typeof(x) == 'string' + } + assert typeof(x) == 'Inner' +} + +__type Expr2 = int | string + +fn test_match_with_reassign_casted_type() { + mut e := Expr2(0) + match union mut e { + int { + e = int(5) + assert e == 5 + } + else {} + } +} + +fn test_if_is_with_reassign_casted_type() { + mut e := Expr2(0) + if e is int { + e = int(5) + assert e == 5 + } +} + +struct Expr2Wrapper { +mut: + expr Expr2 +} + +fn test_change_type_if_is_selector() { + mut e := Expr2Wrapper{Expr2(0)} + if e.expr is int { + e.expr = 'str' + assert e.expr.len == 3 + } + assert e.expr is string +} + +fn test_change_type_if_is() { + mut e := Expr2(0) + if e is int { + e = 'str' + assert e.len == 3 + } + assert e is string +} + +fn test_change_type_match() { + mut e := Expr2(0) + match union mut e { + int { + e = 'str' + assert e.len == 3 + } + else {} + } + assert e is string +} + +__type Expr3 = CallExpr | string + +struct CallExpr { +mut: + is_expr bool +} + +fn test_assign_sum_type_casted_field() { + mut e := Expr3(CallExpr{}) + if e is CallExpr { + e.is_expr = true + assert e.is_expr + } +} + +__type Expr4 = CallExpr2 | CTempVarExpr +struct Expr4Wrapper { +mut: + expr Expr4 +} +struct CallExpr2 { + y int + x string +} + +struct CTempVarExpr { + x string +} + +fn gen(_ Expr4) CTempVarExpr { + return CTempVarExpr{} +} + +fn test_reassign_from_function_with_parameter() { + mut f := Expr4(CallExpr2{}) + if f is CallExpr2 { + f = gen(f) + } +} + +fn test_reassign_from_function_with_parameter_selector() { + mut f := Expr4Wrapper{Expr4(CallExpr2{})} + if f.expr is CallExpr2 { + f.expr = gen(f.expr) + } +} + +fn test_match_multi_branch() { + f := Expr4(CTempVarExpr{'ctemp'}) + mut y := '' + match union f { + CallExpr2, CTempVarExpr { + assert typeof(f) == 'Expr4' + } + } +} + +fn test_sum_type_match() { + // TODO: Remove these casts + assert is_gt_simple('3', int(2)) + assert !is_gt_simple('3', int(5)) + assert is_gt_simple('3', f64(1.2)) + assert !is_gt_simple('3', f64(3.5)) + assert is_gt_nested('3', int(2)) + assert !is_gt_nested('3', int(5)) + assert is_gt_nested('3', f64(1.2)) + assert !is_gt_nested('3', f64(3.5)) + assert concat('3', int(2)) == '3(int)2' + assert concat('3', int(5)) == '3(int)5' + assert concat('3', f64(1.2)) == '3(float)1.2' + assert concat('3', f64(3.5)) == '3(float)3.5' + assert get_sum('3', int(2)) == 5.0 + assert get_sum('3', int(5)) == 8.0 + assert get_sum('3', f64(1.2)) == 4.2 + assert get_sum('3', f64(3.5)) == 6.5 +} diff --git a/vlib/v/token/token.v b/vlib/v/token/token.v index 02d5ab8c54..ded53d0843 100644 --- a/vlib/v/token/token.v +++ b/vlib/v/token/token.v @@ -121,6 +121,7 @@ pub enum Kind { key_struct key_true key_type + key___type // __type key_typeof key_orelse key_union @@ -264,6 +265,7 @@ fn build_token_str() []string { s[Kind.key_lock] = 'lock' s[Kind.key_rlock] = 'rlock' s[Kind.key_type] = 'type' + s[Kind.key___type] = '__type' s[Kind.key_for] = 'for' s[Kind.key_fn] = 'fn' s[Kind.key_true] = 'true'