From 3b116d2455a92f1747daa24e26877dc59914ec45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20D=C3=A4schle?= Date: Fri, 6 Aug 2021 20:26:19 +0200 Subject: [PATCH] cgen: support embedded struct fields on sumtype common fields (#11084) --- vlib/v/ast/table.v | 75 +++++++++++++++++++++++++++++++++++- vlib/v/gen/c/cgen.v | 29 +++++++++++++- vlib/v/tests/sum_type_test.v | 47 +++++++++++++++++++++- 3 files changed, 148 insertions(+), 3 deletions(-) diff --git a/vlib/v/ast/table.v b/vlib/v/ast/table.v index f97238a941..ae8ae054e8 100644 --- a/vlib/v/ast/table.v +++ b/vlib/v/ast/table.v @@ -344,6 +344,28 @@ pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn { return none } +pub struct GetEmbedsOptions { + preceding []Type +} + +// get_embeds returns all nested embedded structs +// the hierarchy of embeds is returned as a list +pub fn (t &Table) get_embeds(sym &TypeSymbol, options GetEmbedsOptions) [][]Type { + mut embeds := [][]Type{} + if sym.info is Struct { + for embed in sym.info.embeds { + embed_sym := t.get_type_symbol(embed) + mut preceding := options.preceding + preceding << embed + embeds << t.get_embeds(embed_sym, preceding: preceding) + } + if sym.info.embeds.len == 0 && options.preceding.len > 0 { + embeds << options.preceding + } + } + return embeds +} + pub fn (t &Table) type_find_method_from_embeds(sym &TypeSymbol, method_name string) ?(Fn, Type) { if sym.info is Struct { mut found_methods := []Fn{} @@ -407,6 +429,20 @@ pub fn (t &Table) struct_has_field(struct_ &TypeSymbol, name string) bool { return false } +// struct_fields returns all fields including fields from embeds +// use this instead symbol.info.fields to get all fields +pub fn (t &Table) struct_fields(sym &TypeSymbol) []StructField { + mut fields := []StructField{} + if sym.info is Struct { + fields << sym.info.fields + for embed in sym.info.embeds { + embed_sym := t.get_type_symbol(embed) + fields << t.struct_fields(embed_sym) + } + } + return fields +} + // search from current type up through each parent looking for field pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField { // println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx') @@ -447,6 +483,43 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField { return none } +// find_field_from_embeds is the same as find_field_from_embeds but also looks into nested embeds +pub fn (t &Table) find_field_from_embeds_recursive(sym &TypeSymbol, field_name string) ?(StructField, []Type) { + if sym.info is Struct { + mut found_fields := []StructField{} + mut embeds_of_found_fields := [][]Type{} + for embed in sym.info.embeds { + embed_sym := t.get_type_symbol(embed) + if field := t.find_field(embed_sym, field_name) { + found_fields << field + embeds_of_found_fields << [embed] + } else { + field, types := t.find_field_from_embeds_recursive(embed_sym, field_name) or { + StructField{}, []Type{} + } + found_fields << field + embeds_of_found_fields << types + } + } + if found_fields.len == 1 { + return found_fields[0], embeds_of_found_fields[0] + } else if found_fields.len > 1 { + return error('ambiguous field `$field_name`') + } + } else if sym.info is Aggregate { + for typ in sym.info.types { + agg_sym := t.get_type_symbol(typ) + field, embed_types := t.find_field_from_embeds_recursive(agg_sym, field_name) or { + return err + } + if embed_types.len > 0 { + return field, embed_types + } + } + } + return none +} + // find_field_from_embeds finds and returns a field in the embeddings of a struct and the embedding type pub fn (t &Table) find_field_from_embeds(sym &TypeSymbol, field_name string) ?(StructField, Type) { if sym.info is Struct { @@ -500,7 +573,7 @@ pub fn (t &Table) resolve_common_sumtype_fields(sym_ &TypeSymbol) { mut v_sym := t.get_type_symbol(variant) fields := match mut v_sym.info { Struct { - v_sym.info.fields + t.struct_fields(v_sym) } SumType { t.resolve_common_sumtype_fields(v_sym) diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 75d9fd02e4..0aac062121 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -1875,14 +1875,41 @@ fn (mut g Gen) write_sumtype_casting_fn(got_ ast.Type, exp_ ast.Type) { got_cname, exp_cname := got_sym.cname, exp_sym.cname sb.writeln('static inline $exp_cname ${got_cname}_to_sumtype_${exp_cname}($got_cname* x) {') sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));') + for embed_hierarchy in g.table.get_embeds(got_sym) { + // last embed in the hierarchy + mut embed_cname := '' + mut embed_name := '' + mut accessor := '&x->' + for j, embed in embed_hierarchy { + embed_sym := g.table.get_type_symbol(embed) + embed_cname = embed_sym.cname + embed_name = embed_sym.embed_name() + if j > 0 { + accessor += '.' + } + accessor += embed_name + } + // if the variable is not used, the C compiler will optimize it away + sb.writeln('\t$embed_cname* ${embed_name}_ptr = memdup($accessor, sizeof($embed_cname));') + } sb.write_string('\treturn ($exp_cname){ ._$got_cname = ptr, ._typ = ${g.type_sidx(got)}') for field in (exp_sym.info as ast.SumType).fields { + mut ptr := 'ptr' + mut type_cname := got_cname + _, embed_types := g.table.find_field_from_embeds_recursive(got_sym, field.name) or { + ast.StructField{}, []ast.Type{} + } + if embed_types.len > 0 { + embed_sym := g.table.get_type_symbol(embed_types.last()) + ptr = '${embed_sym.embed_name()}_ptr' + type_cname = embed_sym.cname + } field_styp := g.typ(field.typ) if got_sym.kind in [.sum_type, .interface_] { // the field is already a wrapped pointer; we shouldn't wrap it once again sb.write_string(', .$field.name = ptr->$field.name') } else { - sb.write_string(', .$field.name = ($field_styp*)((char*)ptr + __offsetof_ptr(ptr, $got_cname, $field.name))') + sb.write_string(', .$field.name = ($field_styp*)((char*)$ptr + __offsetof_ptr($ptr, $type_cname, $field.name))') } } sb.writeln('};\n}') diff --git a/vlib/v/tests/sum_type_test.v b/vlib/v/tests/sum_type_test.v index c1c262eb76..ae2358af1b 100644 --- a/vlib/v/tests/sum_type_test.v +++ b/vlib/v/tests/sum_type_test.v @@ -107,7 +107,7 @@ fn test_converting_down() { } fn test_assignment_and_push() { - mut expr1 := Expr{} + mut expr1 := Expr(IfExpr{}) mut arr1 := []Expr{} expr := IntegerLiteral{ val: '111' @@ -712,3 +712,48 @@ fn test_binary_search_tree() { deleted.sort() assert deleted == [0.0, 0.3, 0.6, 1.0] } + +struct Common { + a int + b int +} + +struct Common2 { + Common +} + +struct Aa { + Common + x int +} + +struct Bb { + Common + x int +} + +struct Cc { + a int +} + +struct Dd { + Common2 +} + +type MySum = Aa | Bb | Cc | Dd + +fn test_sumtype_access_embed_fields() { + a := MySum(Aa{ + a: 1 + }) + assert a.a == 1 +} + +fn test_sumtype_access_nested_embed_fields() { + a := MySum(Dd{ + Common2: Common2{ + a: 2 + } + }) + assert a.a == 2 +}