From f1469a87613dc02c59e2786fdebe26878aec2872 Mon Sep 17 00:00:00 2001 From: spaceface Date: Tue, 9 Mar 2021 18:16:18 +0100 Subject: [PATCH] checker: allow accessing fields common to all sumtype members (#9201) --- vlib/v/ast/ast.v | 1 + vlib/v/checker/checker.v | 4 +- .../sum_type_common_fields_alias_error.out | 20 +++ .../sum_type_common_fields_alias_error.vv | 38 +++++ .../tests/sum_type_common_fields_error.out | 6 + .../tests/sum_type_common_fields_error.vv | 54 +++++++ vlib/v/gen/c/auto_str_methods.v | 2 +- vlib/v/gen/c/cgen.v | 134 +++++++++++------- vlib/v/gen/c/fn.v | 2 +- vlib/v/parser/parser.v | 3 +- vlib/v/table/table.v | 72 ++++++++-- vlib/v/table/types.v | 15 +- vlib/v/tests/sum_type_common_fields_test.v | 48 +++++++ 13 files changed, 330 insertions(+), 69 deletions(-) create mode 100644 vlib/v/checker/tests/sum_type_common_fields_alias_error.out create mode 100644 vlib/v/checker/tests/sum_type_common_fields_alias_error.vv create mode 100644 vlib/v/checker/tests/sum_type_common_fields_error.out create mode 100644 vlib/v/checker/tests/sum_type_common_fields_error.vv create mode 100644 vlib/v/tests/sum_type_common_fields_test.v diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index 271e231e38..8c74d1615a 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -879,6 +879,7 @@ pub: is_pub bool pos token.Position comments []Comment + typ table.Type pub mut: variants []SumTypeVariant } diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 546162f7d7..839693077b 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -2293,7 +2293,7 @@ pub fn (mut c Checker) selector_expr(mut selector_expr ast.SelectorExpr) table.T c.error('ambiguous field `$field_name`', selector_expr.pos) } } - if sym.kind == .aggregate { + if sym.kind in [.aggregate, .sum_type] { unknown_field_msg = err } } @@ -2321,7 +2321,7 @@ pub fn (mut c Checker) selector_expr(mut selector_expr ast.SelectorExpr) table.T selector_expr.typ = field.typ return field.typ } - if sym.kind !in [.struct_, .aggregate, .interface_] { + if sym.kind !in [.struct_, .aggregate, .interface_, .sum_type] { if sym.kind != .placeholder { c.error('`$sym.name` is not a struct', selector_expr.pos) } diff --git a/vlib/v/checker/tests/sum_type_common_fields_alias_error.out b/vlib/v/checker/tests/sum_type_common_fields_alias_error.out new file mode 100644 index 0000000000..20c3b008b0 --- /dev/null +++ b/vlib/v/checker/tests/sum_type_common_fields_alias_error.out @@ -0,0 +1,20 @@ +vlib/v/checker/tests/sum_type_common_fields_alias_error.vv:35:14: error: field `name` does not exist or have the same type in all sumtype variants + 33 | } + 34 | println(m) + 35 | assert m[0].name == 'abc' + | ~~~~ + 36 | assert m[1].name == 'def' + 37 | assert m[2].name == 'xyz' +vlib/v/checker/tests/sum_type_common_fields_alias_error.vv:36:14: error: field `name` does not exist or have the same type in all sumtype variants + 34 | println(m) + 35 | assert m[0].name == 'abc' + 36 | assert m[1].name == 'def' + | ~~~~ + 37 | assert m[2].name == 'xyz' + 38 | } +vlib/v/checker/tests/sum_type_common_fields_alias_error.vv:37:14: error: field `name` does not exist or have the same type in all sumtype variants + 35 | assert m[0].name == 'abc' + 36 | assert m[1].name == 'def' + 37 | assert m[2].name == 'xyz' + | ~~~~ + 38 | } diff --git a/vlib/v/checker/tests/sum_type_common_fields_alias_error.vv b/vlib/v/checker/tests/sum_type_common_fields_alias_error.vv new file mode 100644 index 0000000000..9a296b8f77 --- /dev/null +++ b/vlib/v/checker/tests/sum_type_common_fields_alias_error.vv @@ -0,0 +1,38 @@ +type Main = Sub1 | Sub2 | Sub3 + +// NB: the subtypes will have a common `name` field, of the same `string` +// type, except Sub3, which has `name` of type AliasedString. + +type AliasedString = string + +struct Sub1 { +mut: + name string +} + +struct Sub2 { +mut: + name string +} + +struct Sub3 { +mut: + name AliasedString +} + +fn main() { + mut m := []Main{} + m << Sub1{ + name: 'abc' + } + m << Sub2{ + name: 'def' + } + m << Sub3{ + name: 'xyz' + } + println(m) + assert m[0].name == 'abc' + assert m[1].name == 'def' + assert m[2].name == 'xyz' +} diff --git a/vlib/v/checker/tests/sum_type_common_fields_error.out b/vlib/v/checker/tests/sum_type_common_fields_error.out new file mode 100644 index 0000000000..9dce91ab09 --- /dev/null +++ b/vlib/v/checker/tests/sum_type_common_fields_error.out @@ -0,0 +1,6 @@ +vlib/v/checker/tests/sum_type_common_fields_error.vv:53:14: error: field `val` does not exist or have the same type in all sumtype variants + 51 | assert m[2].name == '64bit integer' + 52 | assert m[3].name == 'string' + 53 | assert m[0].val == 123 + | ~~~ + 54 | } diff --git a/vlib/v/checker/tests/sum_type_common_fields_error.vv b/vlib/v/checker/tests/sum_type_common_fields_error.vv new file mode 100644 index 0000000000..2da560a941 --- /dev/null +++ b/vlib/v/checker/tests/sum_type_common_fields_error.vv @@ -0,0 +1,54 @@ +type Main = Sub1 | Sub2 | Sub3 | Sub4 + +// NB: all subtypes have a common name field, of the same `string` type +// but they also have a field `val` that is of a different type in the +// different subtypes => accessing `m[0].name` is fine, but *not* `m[0].val` +struct Sub1 { +mut: + val int + name string +} + +struct Sub2 { +mut: + val f32 + name string +} + +struct Sub3 { +mut: + val i64 + name string +} + +struct Sub4 { +mut: + val string + name string +} + +fn main() { + mut m := []Main{} + m << Sub1{ + val: 123 + name: 'integer' + } + m << Sub2{ + val: 3.14 + name: 'float' + } + m << Sub3{ + val: 9_876_543_210 + name: '64bit integer' + } + m << Sub4{ + val: 'abcd' + name: 'string' + } + println(m) + assert m[0].name == 'integer' + assert m[1].name == 'float' + assert m[2].name == '64bit integer' + assert m[3].name == 'string' + assert m[0].val == 123 +} diff --git a/vlib/v/gen/c/auto_str_methods.v b/vlib/v/gen/c/auto_str_methods.v index cb7dce4ae6..410b54091f 100644 --- a/vlib/v/gen/c/auto_str_methods.v +++ b/vlib/v/gen/c/auto_str_methods.v @@ -677,7 +677,7 @@ fn (mut g Gen) gen_str_for_union_sum_type(info table.SumType, styp string, str_f clean_sum_type_v_type_name = '&' + clean_sum_type_v_type_name.replace('*', '') } clean_sum_type_v_type_name = util.strip_main_name(clean_sum_type_v_type_name) - g.auto_str_funcs.writeln('\tswitch(x.typ) {') + g.auto_str_funcs.writeln('\tswitch(x._typ) {') for typ in info.variants { mut value_fmt := '%.*s\\000' if typ == table.string_type { diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 63ad2e53f6..97188053aa 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -94,15 +94,16 @@ mut: defer_stmts []ast.DeferStmt defer_ifdef string defer_profile_code string - str_types []string // types that need automatic str() generation - threaded_fns []string // for generating unique wrapper types and fns for `go xxx()` - waiter_fns []string // functions that wait for `go xxx()` to finish - array_fn_definitions []string // array equality functions that have been defined - map_fn_definitions []string // map equality functions that have been defined - struct_fn_definitions []string // struct equality functions that have been defined - alias_fn_definitions []string // alias equality functions that have been defined - auto_fn_definitions []string // auto generated functions defination list - anon_fn_definitions []string // anon generated functions defination list + str_types []string // types that need automatic str() generation + threaded_fns []string // for generating unique wrapper types and fns for `go xxx()` + waiter_fns []string // functions that wait for `go xxx()` to finish + array_fn_definitions []string // array equality functions that have been defined + map_fn_definitions []string // map equality functions that have been defined + struct_fn_definitions []string // struct equality functions that have been defined + alias_fn_definitions []string // alias equality functions that have been defined + auto_fn_definitions []string // auto generated functions defination list + anon_fn_definitions []string // anon generated functions defination list + sumtype_definitions map[int]bool // `_TypeA_to_sumtype_TypeB()` fns that have been generated is_json_fn bool // inside json.encode() json_types []string // to avoid json gen duplicates pcs []ProfileCounterMeta // -prof profile counter fn_names => fn counter name @@ -1553,6 +1554,33 @@ fn (mut g Gen) for_in_stmt(node ast.ForInStmt) { } } +fn (mut g Gen) write_sumtype_casting_fn(got_ table.Type, exp_ table.Type) { + got, exp := got_.idx(), exp_.idx() + i := got | (exp << 16) + if got == exp || g.sumtype_definitions[i] { + return + } + g.sumtype_definitions[i] = true + got_sym := g.table.get_type_symbol(got) + exp_sym := g.table.get_type_symbol(exp) + mut sb := strings.new_builder(128) + got_cname, exp_cname := got_sym.cname, exp_sym.cname + sb.writeln('static inline $exp_cname ${got_cname}_to_sumtype_${exp_cname}($got_cname* x) {') + sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));') + sb.write_string('\treturn ($exp_cname){ ._$got_cname = ptr, ._typ = ${g.type_sidx(got)}') + for field in (exp_sym.info as table.SumType).fields { + field_styp := g.typ(field.typ) + if got_sym.kind in [.sum_type, .interface_] { + // the field is already a wrapped pointer; we shouldn't wrap it once again + sb.write_string(', .$field.name = ptr->$field.name') + } else { + sb.write_string(', .$field.name = ($field_styp*)((char*)ptr + __offsetof($got_cname, $field.name))') + } + } + sb.writeln('};\n}') + g.auto_fn_definitions << sb.str() +} + // use instead of expr() when you need to cast to a different type fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw table.Type, expected_type table.Type) { got_type := g.table.mktyp(got_type_raw) @@ -1590,45 +1618,39 @@ fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw table.Type, expected_t expected_deref_type := if expected_is_ptr { expected_type.deref() } else { expected_type } got_deref_type := if got_is_ptr { got_type.deref() } else { got_type } if g.table.sumtype_has_variant(expected_deref_type, got_deref_type) { - // got_idx := got_type.idx() - got_sidx := g.type_sidx(got_type) - // TODO: do we need 1-3? - if expected_is_ptr && got_is_ptr { - exp_der_styp := g.typ(expected_deref_type) - g.write('/* sum type cast 1 */ ($exp_styp) memdup(&($exp_der_styp){._$got_sym.cname = ') - g.expr(expr) - g.write(', .typ = $got_sidx /* $got_sym.name */}, sizeof($exp_der_styp))') - } else if expected_is_ptr { - exp_der_styp := g.typ(expected_deref_type) - g.write('/* sum type cast 2 */ ($exp_styp) memdup(&($exp_der_styp){._$got_sym.cname = memdup(&($got_styp[]){') - g.expr(expr) - g.write('}, sizeof($got_styp)), .typ = $got_sidx /* $got_sym.name */}, sizeof($exp_der_styp))') - } else if got_is_ptr { - g.write('/* sum type cast 3 */ ($exp_styp){._$got_sym.cname = ') - g.expr(expr) - g.write(', .typ = $got_sidx /* $got_sym.name */}') - } else { - mut is_already_sum_type := false - scope := g.file.scope.innermost(expr.position().pos) - if expr is ast.Ident { - if v := scope.find_var(expr.name) { - if v.sum_type_casts.len > 0 { - is_already_sum_type = true - } - } - } else if expr is ast.SelectorExpr { - if _ := scope.find_struct_field(expr.expr_type, expr.field_name) { + mut is_already_sum_type := false + scope := g.file.scope.innermost(expr.position().pos) + if expr is ast.Ident { + if v := scope.find_var(expr.name) { + if v.sum_type_casts.len > 0 { is_already_sum_type = true } } - if is_already_sum_type { - // Don't create a new sum type wrapper if there is already one - g.prevent_sum_type_unwrapping_once = true + } else if expr is ast.SelectorExpr { + if _ := scope.find_struct_field(expr.expr_type, expr.field_name) { + is_already_sum_type = true + } + } + if is_already_sum_type { + // Don't create a new sum type wrapper if there is already one + g.prevent_sum_type_unwrapping_once = true + g.expr(expr) + } else { + g.write_sumtype_casting_fn(got_type, expected_type) + if expected_is_ptr { + g.write('memdup(&($exp_sym.cname[]){') + } + g.write('${got_sym.cname}_to_sumtype_${exp_sym.cname}(') + if !got_is_ptr { + g.write('(&(($got_styp[]){') g.expr(expr) + g.write('}[0])))') } else { - g.write('/* sum type cast 4 */ ($exp_styp){._$got_sym.cname = memdup(&($got_styp[]){') g.expr(expr) - g.write('}, sizeof($got_styp)), .typ = $got_sidx /* $got_sym.name */}') + g.write(')') + } + if expected_is_ptr { + g.write('}, sizeof($exp_sym.cname))') } } return @@ -2707,7 +2729,7 @@ fn (mut g Gen) expr(node ast.Expr) { g.concat_expr(node) } ast.CTempVar { - // g.write('/*ctmp .orig: $node.orig.str() , .typ: $node.typ, .is_ptr: $node.is_ptr */ ') + // g.write('/*ctmp .orig: $node.orig.str() , ._typ: $node.typ, .is_ptr: $node.is_ptr */ ') g.write(node.name) } ast.EnumVal { @@ -2923,7 +2945,7 @@ fn (mut g Gen) typeof_expr(node ast.TypeOf) { // because the subtype of the expression may change: g.write('tos3( /* $sym.name */ v_typeof_sumtype_${sym.cname}( (') g.expr(node.expr) - g.write(').typ ))') + g.write(')._typ ))') } else if sym.kind == .array_fixed { fixed_info := sym.info as table.ArrayFixed typ_name := g.table.get_type_name(fixed_info.elem_type) @@ -2956,7 +2978,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) { opt_base_typ := g.base_type(node.expr_type) g.writeln('(*($opt_base_typ*)') } - if sym.kind == .interface_ { + if sym.kind in [.interface_, .sum_type] { g.write('(*(') } if sym.kind == .array_fixed { @@ -3037,7 +3059,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) { if sum_type_deref_field != '' { g.write('$sum_type_dot$sum_type_deref_field)') } - if sym.kind == .interface_ { + if sym.kind in [.interface_, .sum_type] { g.write('))') } } @@ -3678,7 +3700,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str g.write(cond_var) dot_or_ptr := if node.cond_type.is_ptr() { '->' } else { '.' } if sym.kind == .sum_type { - g.write('${dot_or_ptr}typ == ') + g.write('${dot_or_ptr}_typ == ') g.expr(branch.exprs[sumtype_index]) } else if sym.kind == .interface_ { typ := branch.exprs[sumtype_index] as ast.Type @@ -5074,13 +5096,19 @@ fn (mut g Gen) write_types(types []table.TypeSymbol) { g.type_definitions.writeln('// | ${variant:4d} = ${g.typ(variant.idx()):-20s}') } g.type_definitions.writeln('struct $name {') - g.type_definitions.writeln(' union {') + g.type_definitions.writeln('\tunion {') for variant in typ.info.variants { variant_sym := g.table.get_type_symbol(variant) - g.type_definitions.writeln(' ${g.typ(variant.to_ptr())} _$variant_sym.cname;') + g.type_definitions.writeln('\t\t${g.typ(variant.to_ptr())} _$variant_sym.cname;') + } + g.type_definitions.writeln('\t};') + g.type_definitions.writeln('\tint _typ;') + if typ.info.fields.len > 0 { + g.writeln('\t// pointers to common sumtype fields') + for field in typ.info.fields { + g.type_definitions.writeln('\t${g.typ(field.typ.to_ptr())} $field.name;') + } } - g.type_definitions.writeln(' };') - g.type_definitions.writeln(' int typ;') g.type_definitions.writeln('};') g.type_definitions.writeln('') } @@ -5691,7 +5719,7 @@ fn (mut g Gen) as_cast(node ast.AsCast) { // g.write('typ, /*expected:*/$node.typ)') sidx := g.type_sidx(node.typ) expected_sym := g.table.get_type_symbol(node.typ) - g.write('typ, $sidx) /*expected idx: $sidx, name: $expected_sym.name */ ') + g.write('_typ, $sidx) /*expected idx: $sidx, name: $expected_sym.name */ ') // fill as cast name table for variant in expr_type_sym.info.variants { @@ -5739,7 +5767,7 @@ fn (mut g Gen) is_expr(node ast.InfixExpr) { g.write('_${c_name(sym.name)}_${c_name(sub_sym.name)}_index') return } else if sym.kind == .sum_type { - g.write('typ $eq ') + g.write('_typ $eq ') } g.expr(node.right) } diff --git a/vlib/v/gen/c/fn.v b/vlib/v/gen/c/fn.v index 852feb95ad..9d7cd99088 100644 --- a/vlib/v/gen/c/fn.v +++ b/vlib/v/gen/c/fn.v @@ -542,7 +542,7 @@ fn (mut g Gen) method_call(node ast.CallExpr) { g.write('tos3( /* $left_sym.name */ v_typeof_sumtype_${typ_sym.cname}( (') g.expr(node.left) dot := if node.left_type.is_ptr() { '->' } else { '.' } - g.write(')${dot}typ ))') + g.write(')${dot}_typ ))') return } if left_sym.kind == .interface_ && node.name == 'type_name' { diff --git a/vlib/v/parser/parser.v b/vlib/v/parser/parser.v index d53dd62708..01773c9c5c 100644 --- a/vlib/v/parser/parser.v +++ b/vlib/v/parser/parser.v @@ -2318,7 +2318,7 @@ fn (mut p Parser) type_decl() ast.TypeDecl { } variant_types := sum_variants.map(it.typ) prepend_mod_name := p.prepend_mod(name) - p.table.register_type_symbol(table.TypeSymbol{ + typ := p.table.register_type_symbol(table.TypeSymbol{ kind: .sum_type name: prepend_mod_name cname: util.no_dots(prepend_mod_name) @@ -2331,6 +2331,7 @@ fn (mut p Parser) type_decl() ast.TypeDecl { comments = p.eat_comments(same_line: true) return ast.SumTypeDecl{ name: name + typ: typ is_pub: is_pub variants: sum_variants pos: decl_pos diff --git a/vlib/v/table/table.v b/vlib/v/table/table.v index 8149b86652..83b38cedc1 100644 --- a/vlib/v/table/table.v +++ b/vlib/v/table/table.v @@ -277,20 +277,32 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?Field { // println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx') mut ts := s for { - if mut ts.info is Struct { - if field := ts.info.find_field(name) { + match mut ts.info { + Struct { + if field := ts.info.find_field(name) { + return field + } + } + Aggregate { + if field := ts.info.find_field(name) { + return field + } + field := t.register_aggregate_field(mut ts, name) or { return err } return field } - } else if mut ts.info is Aggregate { - if field := ts.info.find_field(name) { - return field + Interface { + if field := ts.info.find_field(name) { + return field + } } - field := t.register_aggregate_field(mut ts, name) or { return err } - return field - } else if mut ts.info is Interface { - if field := ts.info.find_field(name) { - return field + SumType { + t.resolve_common_sumtype_fields(s) + if field := ts.info.find_field(name) { + return field + } + return error('field `$name` does not exist or have the same type in all sumtype variants') } + else {} } if ts.parent_idx == 0 { break @@ -326,6 +338,46 @@ pub fn (t &Table) find_field_with_embeds(sym &TypeSymbol, field_name string) ?Fi } } +pub fn (t &Table) resolve_common_sumtype_fields(sym_ &TypeSymbol) { + mut sym := sym_ + mut info := sym.info as SumType + if info.found_fields { + return + } + mut field_map := map[string]Field{} + mut field_usages := map[string]int{} + for variant in info.variants { + mut v_sym := t.get_type_symbol(variant) + fields := match mut v_sym.info { + Struct { + v_sym.info.fields + } + SumType { + t.resolve_common_sumtype_fields(v_sym) + v_sym.info.fields + } + else { + []Field{} + } + } + for field in fields { + if field.name !in field_map { + field_map[field.name] = field + field_usages[field.name]++ + } else if field.equals(field_map[field.name]) { + field_usages[field.name]++ + } + } + } + for field, nr_definitions in field_usages { + if nr_definitions == info.variants.len { + info.fields << field_map[field] + } + } + info.found_fields = true + sym.info = info +} + [inline] pub fn (t &Table) find_type_idx(name string) int { return t.type_idxs[name] diff --git a/vlib/v/table/types.v b/vlib/v/table/types.v index 9384697a54..22390896ed 100644 --- a/vlib/v/table/types.v +++ b/vlib/v/table/types.v @@ -713,7 +713,7 @@ pub mut: is_global bool } -fn (f &Field) equals(o &Field) bool { +pub fn (f &Field) equals(o &Field) bool { // TODO: f.is_mut == o.is_mut was removed here to allow read only access // to (mut/not mut), but otherwise equal fields; some other new checks are needed: // - if node is declared mut, and we mutate node.stmts, all stmts fields must be mutable @@ -755,6 +755,9 @@ pub mut: pub struct SumType { pub: variants []Type +pub mut: + fields []Field + found_fields bool } // human readable type name @@ -971,6 +974,7 @@ pub fn (t &TypeSymbol) find_field(name string) ?Field { Aggregate { return t.info.find_field(name) } Struct { return t.info.find_field(name) } Interface { return t.info.find_field(name) } + SumType { return t.info.find_field(name) } else { return none } } } @@ -1025,6 +1029,15 @@ pub fn (s Struct) get_field(name string) Field { panic('unknown field `$name`') } +pub fn (s &SumType) find_field(name string) ?Field { + for field in s.fields { + if field.name == name { + return field + } + } + return none +} + pub fn (i Interface) defines_method(name string) bool { for method in i.methods { if method.name == name { diff --git a/vlib/v/tests/sum_type_common_fields_test.v b/vlib/v/tests/sum_type_common_fields_test.v new file mode 100644 index 0000000000..e935fbe810 --- /dev/null +++ b/vlib/v/tests/sum_type_common_fields_test.v @@ -0,0 +1,48 @@ +type Master = Sub1 | Sub2 + +struct Sub1 { +mut: + val int + name string +} + +struct Sub2 { + name string + val int +} + +struct Sub3 { + name string + val int +} + +type Master2 = Master | Sub3 + +fn test_common_sumtype_field_access() { + mut out := []Master{} + out << Sub1{ + val: 1 + name: 'one' + } + out << Sub2{ + val: 2 + name: 'two' + } + out << Sub2{ + val: 3 + name: 'three' + } + assert out[0].val == 1 + assert out[0].name == 'one' + + assert out[1].val == 2 + assert out[1].name == 'two' + + assert out[2].val == 3 + assert out[2].name == 'three' + + mut out0 := Master2(out[0]) // common fields on a doubly-wrapped sumtype + assert out0.val == 1 + assert out0.name == 'one' +} +