checker: allow accessing fields common to all sumtype members (#9201)

pull/9221/head
spaceface 2021-03-09 18:16:18 +01:00 committed by GitHub
parent eed6f7dbff
commit f1469a8761
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 330 additions and 69 deletions

View File

@ -879,6 +879,7 @@ pub:
is_pub bool is_pub bool
pos token.Position pos token.Position
comments []Comment comments []Comment
typ table.Type
pub mut: pub mut:
variants []SumTypeVariant variants []SumTypeVariant
} }

View File

@ -2293,7 +2293,7 @@ pub fn (mut c Checker) selector_expr(mut selector_expr ast.SelectorExpr) table.T
c.error('ambiguous field `$field_name`', selector_expr.pos) c.error('ambiguous field `$field_name`', selector_expr.pos)
} }
} }
if sym.kind == .aggregate { if sym.kind in [.aggregate, .sum_type] {
unknown_field_msg = err unknown_field_msg = err
} }
} }
@ -2321,7 +2321,7 @@ pub fn (mut c Checker) selector_expr(mut selector_expr ast.SelectorExpr) table.T
selector_expr.typ = field.typ selector_expr.typ = field.typ
return field.typ return field.typ
} }
if sym.kind !in [.struct_, .aggregate, .interface_] { if sym.kind !in [.struct_, .aggregate, .interface_, .sum_type] {
if sym.kind != .placeholder { if sym.kind != .placeholder {
c.error('`$sym.name` is not a struct', selector_expr.pos) c.error('`$sym.name` is not a struct', selector_expr.pos)
} }

View File

@ -0,0 +1,20 @@
vlib/v/checker/tests/sum_type_common_fields_alias_error.vv:35:14: error: field `name` does not exist or have the same type in all sumtype variants
33 | }
34 | println(m)
35 | assert m[0].name == 'abc'
| ~~~~
36 | assert m[1].name == 'def'
37 | assert m[2].name == 'xyz'
vlib/v/checker/tests/sum_type_common_fields_alias_error.vv:36:14: error: field `name` does not exist or have the same type in all sumtype variants
34 | println(m)
35 | assert m[0].name == 'abc'
36 | assert m[1].name == 'def'
| ~~~~
37 | assert m[2].name == 'xyz'
38 | }
vlib/v/checker/tests/sum_type_common_fields_alias_error.vv:37:14: error: field `name` does not exist or have the same type in all sumtype variants
35 | assert m[0].name == 'abc'
36 | assert m[1].name == 'def'
37 | assert m[2].name == 'xyz'
| ~~~~
38 | }

View File

@ -0,0 +1,38 @@
type Main = Sub1 | Sub2 | Sub3
// NB: the subtypes will have a common `name` field, of the same `string`
// type, except Sub3, which has `name` of type AliasedString.
type AliasedString = string
struct Sub1 {
mut:
name string
}
struct Sub2 {
mut:
name string
}
struct Sub3 {
mut:
name AliasedString
}
fn main() {
mut m := []Main{}
m << Sub1{
name: 'abc'
}
m << Sub2{
name: 'def'
}
m << Sub3{
name: 'xyz'
}
println(m)
assert m[0].name == 'abc'
assert m[1].name == 'def'
assert m[2].name == 'xyz'
}

View File

@ -0,0 +1,6 @@
vlib/v/checker/tests/sum_type_common_fields_error.vv:53:14: error: field `val` does not exist or have the same type in all sumtype variants
51 | assert m[2].name == '64bit integer'
52 | assert m[3].name == 'string'
53 | assert m[0].val == 123
| ~~~
54 | }

View File

@ -0,0 +1,54 @@
type Main = Sub1 | Sub2 | Sub3 | Sub4
// NB: all subtypes have a common name field, of the same `string` type
// but they also have a field `val` that is of a different type in the
// different subtypes => accessing `m[0].name` is fine, but *not* `m[0].val`
struct Sub1 {
mut:
val int
name string
}
struct Sub2 {
mut:
val f32
name string
}
struct Sub3 {
mut:
val i64
name string
}
struct Sub4 {
mut:
val string
name string
}
fn main() {
mut m := []Main{}
m << Sub1{
val: 123
name: 'integer'
}
m << Sub2{
val: 3.14
name: 'float'
}
m << Sub3{
val: 9_876_543_210
name: '64bit integer'
}
m << Sub4{
val: 'abcd'
name: 'string'
}
println(m)
assert m[0].name == 'integer'
assert m[1].name == 'float'
assert m[2].name == '64bit integer'
assert m[3].name == 'string'
assert m[0].val == 123
}

View File

@ -677,7 +677,7 @@ fn (mut g Gen) gen_str_for_union_sum_type(info table.SumType, styp string, str_f
clean_sum_type_v_type_name = '&' + clean_sum_type_v_type_name.replace('*', '') clean_sum_type_v_type_name = '&' + clean_sum_type_v_type_name.replace('*', '')
} }
clean_sum_type_v_type_name = util.strip_main_name(clean_sum_type_v_type_name) clean_sum_type_v_type_name = util.strip_main_name(clean_sum_type_v_type_name)
g.auto_str_funcs.writeln('\tswitch(x.typ) {') g.auto_str_funcs.writeln('\tswitch(x._typ) {')
for typ in info.variants { for typ in info.variants {
mut value_fmt := '%.*s\\000' mut value_fmt := '%.*s\\000'
if typ == table.string_type { if typ == table.string_type {

View File

@ -103,6 +103,7 @@ mut:
alias_fn_definitions []string // alias equality functions that have been defined alias_fn_definitions []string // alias equality functions that have been defined
auto_fn_definitions []string // auto generated functions defination list auto_fn_definitions []string // auto generated functions defination list
anon_fn_definitions []string // anon generated functions defination list anon_fn_definitions []string // anon generated functions defination list
sumtype_definitions map[int]bool // `_TypeA_to_sumtype_TypeB()` fns that have been generated
is_json_fn bool // inside json.encode() is_json_fn bool // inside json.encode()
json_types []string // to avoid json gen duplicates json_types []string // to avoid json gen duplicates
pcs []ProfileCounterMeta // -prof profile counter fn_names => fn counter name pcs []ProfileCounterMeta // -prof profile counter fn_names => fn counter name
@ -1553,6 +1554,33 @@ fn (mut g Gen) for_in_stmt(node ast.ForInStmt) {
} }
} }
fn (mut g Gen) write_sumtype_casting_fn(got_ table.Type, exp_ table.Type) {
got, exp := got_.idx(), exp_.idx()
i := got | (exp << 16)
if got == exp || g.sumtype_definitions[i] {
return
}
g.sumtype_definitions[i] = true
got_sym := g.table.get_type_symbol(got)
exp_sym := g.table.get_type_symbol(exp)
mut sb := strings.new_builder(128)
got_cname, exp_cname := got_sym.cname, exp_sym.cname
sb.writeln('static inline $exp_cname ${got_cname}_to_sumtype_${exp_cname}($got_cname* x) {')
sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));')
sb.write_string('\treturn ($exp_cname){ ._$got_cname = ptr, ._typ = ${g.type_sidx(got)}')
for field in (exp_sym.info as table.SumType).fields {
field_styp := g.typ(field.typ)
if got_sym.kind in [.sum_type, .interface_] {
// the field is already a wrapped pointer; we shouldn't wrap it once again
sb.write_string(', .$field.name = ptr->$field.name')
} else {
sb.write_string(', .$field.name = ($field_styp*)((char*)ptr + __offsetof($got_cname, $field.name))')
}
}
sb.writeln('};\n}')
g.auto_fn_definitions << sb.str()
}
// use instead of expr() when you need to cast to a different type // use instead of expr() when you need to cast to a different type
fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw table.Type, expected_type table.Type) { fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw table.Type, expected_type table.Type) {
got_type := g.table.mktyp(got_type_raw) got_type := g.table.mktyp(got_type_raw)
@ -1590,24 +1618,6 @@ fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw table.Type, expected_t
expected_deref_type := if expected_is_ptr { expected_type.deref() } else { expected_type } expected_deref_type := if expected_is_ptr { expected_type.deref() } else { expected_type }
got_deref_type := if got_is_ptr { got_type.deref() } else { got_type } got_deref_type := if got_is_ptr { got_type.deref() } else { got_type }
if g.table.sumtype_has_variant(expected_deref_type, got_deref_type) { if g.table.sumtype_has_variant(expected_deref_type, got_deref_type) {
// got_idx := got_type.idx()
got_sidx := g.type_sidx(got_type)
// TODO: do we need 1-3?
if expected_is_ptr && got_is_ptr {
exp_der_styp := g.typ(expected_deref_type)
g.write('/* sum type cast 1 */ ($exp_styp) memdup(&($exp_der_styp){._$got_sym.cname = ')
g.expr(expr)
g.write(', .typ = $got_sidx /* $got_sym.name */}, sizeof($exp_der_styp))')
} else if expected_is_ptr {
exp_der_styp := g.typ(expected_deref_type)
g.write('/* sum type cast 2 */ ($exp_styp) memdup(&($exp_der_styp){._$got_sym.cname = memdup(&($got_styp[]){')
g.expr(expr)
g.write('}, sizeof($got_styp)), .typ = $got_sidx /* $got_sym.name */}, sizeof($exp_der_styp))')
} else if got_is_ptr {
g.write('/* sum type cast 3 */ ($exp_styp){._$got_sym.cname = ')
g.expr(expr)
g.write(', .typ = $got_sidx /* $got_sym.name */}')
} else {
mut is_already_sum_type := false mut is_already_sum_type := false
scope := g.file.scope.innermost(expr.position().pos) scope := g.file.scope.innermost(expr.position().pos)
if expr is ast.Ident { if expr is ast.Ident {
@ -1626,9 +1636,21 @@ fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw table.Type, expected_t
g.prevent_sum_type_unwrapping_once = true g.prevent_sum_type_unwrapping_once = true
g.expr(expr) g.expr(expr)
} else { } else {
g.write('/* sum type cast 4 */ ($exp_styp){._$got_sym.cname = memdup(&($got_styp[]){') g.write_sumtype_casting_fn(got_type, expected_type)
if expected_is_ptr {
g.write('memdup(&($exp_sym.cname[]){')
}
g.write('${got_sym.cname}_to_sumtype_${exp_sym.cname}(')
if !got_is_ptr {
g.write('(&(($got_styp[]){')
g.expr(expr) g.expr(expr)
g.write('}, sizeof($got_styp)), .typ = $got_sidx /* $got_sym.name */}') g.write('}[0])))')
} else {
g.expr(expr)
g.write(')')
}
if expected_is_ptr {
g.write('}, sizeof($exp_sym.cname))')
} }
} }
return return
@ -2707,7 +2729,7 @@ fn (mut g Gen) expr(node ast.Expr) {
g.concat_expr(node) g.concat_expr(node)
} }
ast.CTempVar { ast.CTempVar {
// g.write('/*ctmp .orig: $node.orig.str() , .typ: $node.typ, .is_ptr: $node.is_ptr */ ') // g.write('/*ctmp .orig: $node.orig.str() , ._typ: $node.typ, .is_ptr: $node.is_ptr */ ')
g.write(node.name) g.write(node.name)
} }
ast.EnumVal { ast.EnumVal {
@ -2923,7 +2945,7 @@ fn (mut g Gen) typeof_expr(node ast.TypeOf) {
// because the subtype of the expression may change: // because the subtype of the expression may change:
g.write('tos3( /* $sym.name */ v_typeof_sumtype_${sym.cname}( (') g.write('tos3( /* $sym.name */ v_typeof_sumtype_${sym.cname}( (')
g.expr(node.expr) g.expr(node.expr)
g.write(').typ ))') g.write(')._typ ))')
} else if sym.kind == .array_fixed { } else if sym.kind == .array_fixed {
fixed_info := sym.info as table.ArrayFixed fixed_info := sym.info as table.ArrayFixed
typ_name := g.table.get_type_name(fixed_info.elem_type) typ_name := g.table.get_type_name(fixed_info.elem_type)
@ -2956,7 +2978,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) {
opt_base_typ := g.base_type(node.expr_type) opt_base_typ := g.base_type(node.expr_type)
g.writeln('(*($opt_base_typ*)') g.writeln('(*($opt_base_typ*)')
} }
if sym.kind == .interface_ { if sym.kind in [.interface_, .sum_type] {
g.write('(*(') g.write('(*(')
} }
if sym.kind == .array_fixed { if sym.kind == .array_fixed {
@ -3037,7 +3059,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) {
if sum_type_deref_field != '' { if sum_type_deref_field != '' {
g.write('$sum_type_dot$sum_type_deref_field)') g.write('$sum_type_dot$sum_type_deref_field)')
} }
if sym.kind == .interface_ { if sym.kind in [.interface_, .sum_type] {
g.write('))') g.write('))')
} }
} }
@ -3678,7 +3700,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
g.write(cond_var) g.write(cond_var)
dot_or_ptr := if node.cond_type.is_ptr() { '->' } else { '.' } dot_or_ptr := if node.cond_type.is_ptr() { '->' } else { '.' }
if sym.kind == .sum_type { if sym.kind == .sum_type {
g.write('${dot_or_ptr}typ == ') g.write('${dot_or_ptr}_typ == ')
g.expr(branch.exprs[sumtype_index]) g.expr(branch.exprs[sumtype_index])
} else if sym.kind == .interface_ { } else if sym.kind == .interface_ {
typ := branch.exprs[sumtype_index] as ast.Type typ := branch.exprs[sumtype_index] as ast.Type
@ -5074,13 +5096,19 @@ fn (mut g Gen) write_types(types []table.TypeSymbol) {
g.type_definitions.writeln('// | ${variant:4d} = ${g.typ(variant.idx()):-20s}') g.type_definitions.writeln('// | ${variant:4d} = ${g.typ(variant.idx()):-20s}')
} }
g.type_definitions.writeln('struct $name {') g.type_definitions.writeln('struct $name {')
g.type_definitions.writeln(' union {') g.type_definitions.writeln('\tunion {')
for variant in typ.info.variants { for variant in typ.info.variants {
variant_sym := g.table.get_type_symbol(variant) variant_sym := g.table.get_type_symbol(variant)
g.type_definitions.writeln(' ${g.typ(variant.to_ptr())} _$variant_sym.cname;') g.type_definitions.writeln('\t\t${g.typ(variant.to_ptr())} _$variant_sym.cname;')
}
g.type_definitions.writeln('\t};')
g.type_definitions.writeln('\tint _typ;')
if typ.info.fields.len > 0 {
g.writeln('\t// pointers to common sumtype fields')
for field in typ.info.fields {
g.type_definitions.writeln('\t${g.typ(field.typ.to_ptr())} $field.name;')
}
} }
g.type_definitions.writeln(' };')
g.type_definitions.writeln(' int typ;')
g.type_definitions.writeln('};') g.type_definitions.writeln('};')
g.type_definitions.writeln('') g.type_definitions.writeln('')
} }
@ -5691,7 +5719,7 @@ fn (mut g Gen) as_cast(node ast.AsCast) {
// g.write('typ, /*expected:*/$node.typ)') // g.write('typ, /*expected:*/$node.typ)')
sidx := g.type_sidx(node.typ) sidx := g.type_sidx(node.typ)
expected_sym := g.table.get_type_symbol(node.typ) expected_sym := g.table.get_type_symbol(node.typ)
g.write('typ, $sidx) /*expected idx: $sidx, name: $expected_sym.name */ ') g.write('_typ, $sidx) /*expected idx: $sidx, name: $expected_sym.name */ ')
// fill as cast name table // fill as cast name table
for variant in expr_type_sym.info.variants { for variant in expr_type_sym.info.variants {
@ -5739,7 +5767,7 @@ fn (mut g Gen) is_expr(node ast.InfixExpr) {
g.write('_${c_name(sym.name)}_${c_name(sub_sym.name)}_index') g.write('_${c_name(sym.name)}_${c_name(sub_sym.name)}_index')
return return
} else if sym.kind == .sum_type { } else if sym.kind == .sum_type {
g.write('typ $eq ') g.write('_typ $eq ')
} }
g.expr(node.right) g.expr(node.right)
} }

View File

@ -542,7 +542,7 @@ fn (mut g Gen) method_call(node ast.CallExpr) {
g.write('tos3( /* $left_sym.name */ v_typeof_sumtype_${typ_sym.cname}( (') g.write('tos3( /* $left_sym.name */ v_typeof_sumtype_${typ_sym.cname}( (')
g.expr(node.left) g.expr(node.left)
dot := if node.left_type.is_ptr() { '->' } else { '.' } dot := if node.left_type.is_ptr() { '->' } else { '.' }
g.write(')${dot}typ ))') g.write(')${dot}_typ ))')
return return
} }
if left_sym.kind == .interface_ && node.name == 'type_name' { if left_sym.kind == .interface_ && node.name == 'type_name' {

View File

@ -2318,7 +2318,7 @@ fn (mut p Parser) type_decl() ast.TypeDecl {
} }
variant_types := sum_variants.map(it.typ) variant_types := sum_variants.map(it.typ)
prepend_mod_name := p.prepend_mod(name) prepend_mod_name := p.prepend_mod(name)
p.table.register_type_symbol(table.TypeSymbol{ typ := p.table.register_type_symbol(table.TypeSymbol{
kind: .sum_type kind: .sum_type
name: prepend_mod_name name: prepend_mod_name
cname: util.no_dots(prepend_mod_name) cname: util.no_dots(prepend_mod_name)
@ -2331,6 +2331,7 @@ fn (mut p Parser) type_decl() ast.TypeDecl {
comments = p.eat_comments(same_line: true) comments = p.eat_comments(same_line: true)
return ast.SumTypeDecl{ return ast.SumTypeDecl{
name: name name: name
typ: typ
is_pub: is_pub is_pub: is_pub
variants: sum_variants variants: sum_variants
pos: decl_pos pos: decl_pos

View File

@ -277,21 +277,33 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?Field {
// println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx') // println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx')
mut ts := s mut ts := s
for { for {
if mut ts.info is Struct { match mut ts.info {
Struct {
if field := ts.info.find_field(name) { if field := ts.info.find_field(name) {
return field return field
} }
} else if mut ts.info is Aggregate { }
Aggregate {
if field := ts.info.find_field(name) { if field := ts.info.find_field(name) {
return field return field
} }
field := t.register_aggregate_field(mut ts, name) or { return err } field := t.register_aggregate_field(mut ts, name) or { return err }
return field return field
} else if mut ts.info is Interface { }
Interface {
if field := ts.info.find_field(name) { if field := ts.info.find_field(name) {
return field return field
} }
} }
SumType {
t.resolve_common_sumtype_fields(s)
if field := ts.info.find_field(name) {
return field
}
return error('field `$name` does not exist or have the same type in all sumtype variants')
}
else {}
}
if ts.parent_idx == 0 { if ts.parent_idx == 0 {
break break
} }
@ -326,6 +338,46 @@ pub fn (t &Table) find_field_with_embeds(sym &TypeSymbol, field_name string) ?Fi
} }
} }
pub fn (t &Table) resolve_common_sumtype_fields(sym_ &TypeSymbol) {
mut sym := sym_
mut info := sym.info as SumType
if info.found_fields {
return
}
mut field_map := map[string]Field{}
mut field_usages := map[string]int{}
for variant in info.variants {
mut v_sym := t.get_type_symbol(variant)
fields := match mut v_sym.info {
Struct {
v_sym.info.fields
}
SumType {
t.resolve_common_sumtype_fields(v_sym)
v_sym.info.fields
}
else {
[]Field{}
}
}
for field in fields {
if field.name !in field_map {
field_map[field.name] = field
field_usages[field.name]++
} else if field.equals(field_map[field.name]) {
field_usages[field.name]++
}
}
}
for field, nr_definitions in field_usages {
if nr_definitions == info.variants.len {
info.fields << field_map[field]
}
}
info.found_fields = true
sym.info = info
}
[inline] [inline]
pub fn (t &Table) find_type_idx(name string) int { pub fn (t &Table) find_type_idx(name string) int {
return t.type_idxs[name] return t.type_idxs[name]

View File

@ -713,7 +713,7 @@ pub mut:
is_global bool is_global bool
} }
fn (f &Field) equals(o &Field) bool { pub fn (f &Field) equals(o &Field) bool {
// TODO: f.is_mut == o.is_mut was removed here to allow read only access // TODO: f.is_mut == o.is_mut was removed here to allow read only access
// to (mut/not mut), but otherwise equal fields; some other new checks are needed: // to (mut/not mut), but otherwise equal fields; some other new checks are needed:
// - if node is declared mut, and we mutate node.stmts, all stmts fields must be mutable // - if node is declared mut, and we mutate node.stmts, all stmts fields must be mutable
@ -755,6 +755,9 @@ pub mut:
pub struct SumType { pub struct SumType {
pub: pub:
variants []Type variants []Type
pub mut:
fields []Field
found_fields bool
} }
// human readable type name // human readable type name
@ -971,6 +974,7 @@ pub fn (t &TypeSymbol) find_field(name string) ?Field {
Aggregate { return t.info.find_field(name) } Aggregate { return t.info.find_field(name) }
Struct { return t.info.find_field(name) } Struct { return t.info.find_field(name) }
Interface { return t.info.find_field(name) } Interface { return t.info.find_field(name) }
SumType { return t.info.find_field(name) }
else { return none } else { return none }
} }
} }
@ -1025,6 +1029,15 @@ pub fn (s Struct) get_field(name string) Field {
panic('unknown field `$name`') panic('unknown field `$name`')
} }
pub fn (s &SumType) find_field(name string) ?Field {
for field in s.fields {
if field.name == name {
return field
}
}
return none
}
pub fn (i Interface) defines_method(name string) bool { pub fn (i Interface) defines_method(name string) bool {
for method in i.methods { for method in i.methods {
if method.name == name { if method.name == name {

View File

@ -0,0 +1,48 @@
type Master = Sub1 | Sub2
struct Sub1 {
mut:
val int
name string
}
struct Sub2 {
name string
val int
}
struct Sub3 {
name string
val int
}
type Master2 = Master | Sub3
fn test_common_sumtype_field_access() {
mut out := []Master{}
out << Sub1{
val: 1
name: 'one'
}
out << Sub2{
val: 2
name: 'two'
}
out << Sub2{
val: 3
name: 'three'
}
assert out[0].val == 1
assert out[0].name == 'one'
assert out[1].val == 2
assert out[1].name == 'two'
assert out[2].val == 3
assert out[2].name == 'three'
mut out0 := Master2(out[0]) // common fields on a doubly-wrapped sumtype
assert out0.val == 1
assert out0.name == 'one'
}