diff --git a/vlib/v/ast/table.v b/vlib/v/ast/table.v index bb30ef0cda..341c7c3189 100644 --- a/vlib/v/ast/table.v +++ b/vlib/v/ast/table.v @@ -321,7 +321,7 @@ pub fn (t &Table) type_has_method(s &TypeSymbol, name string) bool { return false } -// search from current type up through each parent looking for method +// type_find_method searches from current type up through each parent looking for method pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn { // println('type_find_method($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx') mut ts := unsafe { s } @@ -340,6 +340,36 @@ pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn { return none } +pub fn (t &Table) type_find_method_from_embeds(sym &TypeSymbol, method_name string) ?(Fn, Type) { + if sym.info is Struct { + mut found_methods := []Fn{} + mut embed_of_found_methods := []Type{} + for embed in sym.info.embeds { + embed_sym := t.get_type_symbol(embed) + if m := t.type_find_method(embed_sym, method_name) { + found_methods << m + embed_of_found_methods << embed + } + } + if found_methods.len == 1 { + return found_methods[0], embed_of_found_methods[0] + } else if found_methods.len > 1 { + return error('ambiguous method `$method_name`') + } + } else if sym.info is Aggregate { + for typ in sym.info.types { + agg_sym := t.get_type_symbol(typ) + method, embed_type := t.type_find_method_from_embeds(agg_sym, method_name) or { + return err + } + if embed_type != 0 { + return method, embed_type + } + } + } + return none +} + fn (t &Table) register_aggregate_field(mut sym TypeSymbol, name string) ?StructField { if sym.kind != .aggregate { t.panic('Unexpected type symbol: $sym.kind') @@ -347,9 +377,7 @@ fn (t &Table) register_aggregate_field(mut sym TypeSymbol, name string) ?StructF mut agg_info := sym.info as Aggregate // an aggregate always has at least 2 types mut found_once := false - mut new_field := StructField{ - // default_expr: ast.empty_expr() - } + mut new_field := StructField{} for typ in agg_info.types { ts := t.get_type_symbol(typ) if type_field := t.find_field(ts, name) { @@ -415,29 +443,44 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField { return none } -// find_field_with_embeds searches for a given field, also looking through embedded fields -pub fn (t &Table) find_field_with_embeds(sym &TypeSymbol, field_name string) ?StructField { - if f := t.find_field(sym, field_name) { - return f - } else { - // look for embedded field - if sym.info is Struct { - mut found_fields := []StructField{} - mut embed_of_found_fields := []Type{} - for embed in sym.info.embeds { - embed_sym := t.get_type_symbol(embed) - if f := t.find_field(embed_sym, field_name) { - found_fields << f - embed_of_found_fields << embed - } - } - if found_fields.len == 1 { - return found_fields[0] - } else if found_fields.len > 1 { - return error('ambiguous field `$field_name`') +// find_field_from_embeds finds and returns a field in the embeddings of a struct and the embedding type +pub fn (t &Table) find_field_from_embeds(sym &TypeSymbol, field_name string) ?(StructField, Type) { + if sym.info is Struct { + mut found_fields := []StructField{} + mut embed_of_found_fields := []Type{} + for embed in sym.info.embeds { + embed_sym := t.get_type_symbol(embed) + if field := t.find_field(embed_sym, field_name) { + found_fields << field + embed_of_found_fields << embed } } - return err + if found_fields.len == 1 { + return found_fields[0], embed_of_found_fields[0] + } else if found_fields.len > 1 { + return error('ambiguous field `$field_name`') + } + } else if sym.info is Aggregate { + for typ in sym.info.types { + agg_sym := t.get_type_symbol(typ) + field, embed_type := t.find_field_from_embeds(agg_sym, field_name) or { return err } + if embed_type != 0 { + return field, embed_type + } + } + } + return none +} + +// find_field_with_embeds searches for a given field, also looking through embedded fields +pub fn (t &Table) find_field_with_embeds(sym &TypeSymbol, field_name string) ?StructField { + if field := t.find_field(sym, field_name) { + return field + } else { + // look for embedded field + first_err := err + field, _ := t.find_field_from_embeds(sym, field_name) or { return first_err } + return field } } diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index d66dabde9f..cfe00b30b7 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -2118,7 +2118,6 @@ pub fn (mut c Checker) method_call(mut call_expr ast.CallExpr) ast.Type { method = m has_method = true } else { - // can this logic be moved to ast.type_find_method() so it can be used from anywhere if left_type_sym.info is ast.Struct { if left_type_sym.info.parent_type != 0 { type_sym := c.table.get_type_symbol(left_type_sym.info.parent_type) @@ -2127,24 +2126,21 @@ pub fn (mut c Checker) method_call(mut call_expr ast.CallExpr) ast.Type { has_method = true is_generic = true } - } else { - mut found_methods := []ast.Fn{} - mut embed_of_found_methods := []ast.Type{} - for embed in left_type_sym.info.embeds { - embed_sym := c.table.get_type_symbol(embed) - if m := c.table.type_find_method(embed_sym, method_name) { - found_methods << m - embed_of_found_methods << embed - } - } - if found_methods.len == 1 { - method = found_methods[0] - has_method = true - is_method_from_embed = true - call_expr.from_embed_type = embed_of_found_methods[0] - } else if found_methods.len > 1 { - c.error('ambiguous method `$method_name`', call_expr.pos) + } + } + if !has_method { + has_method = true + mut embed_type := ast.Type(0) + method, embed_type = c.table.type_find_method_from_embeds(left_type_sym, method_name) or { + if err.msg != '' { + c.error(err.msg, call_expr.pos) } + has_method = false + ast.Fn{}, ast.Type(0) + } + if embed_type != 0 { + is_method_from_embed = true + call_expr.from_embed_type = embed_type } } if left_type_sym.kind == .aggregate { @@ -3379,24 +3375,16 @@ pub fn (mut c Checker) selector_expr(mut node ast.SelectorExpr) ast.Type { field = f } else { // look for embedded field - if sym.info is ast.Struct { - mut found_fields := []ast.StructField{} - mut embed_of_found_fields := []ast.Type{} - for embed in sym.info.embeds { - embed_sym := c.table.get_type_symbol(embed) - if f := c.table.find_field(embed_sym, field_name) { - found_fields << f - embed_of_found_fields << embed - } - } - if found_fields.len == 1 { - field = found_fields[0] - has_field = true - node.from_embed_type = embed_of_found_fields[0] - } else if found_fields.len > 1 { - c.error('ambiguous field `$field_name`', node.pos) + has_field = true + mut embed_type := ast.Type(0) + field, embed_type = c.table.find_field_from_embeds(sym, field_name) or { + if err.msg != '' { + c.error(err.msg, node.pos) } + has_field = false + ast.StructField{}, ast.Type(0) } + node.from_embed_type = embed_type if sym.kind in [.aggregate, .sum_type] { unknown_field_msg = err.msg } @@ -3416,24 +3404,16 @@ pub fn (mut c Checker) selector_expr(mut node ast.SelectorExpr) ast.Type { field = f } else { // look for embedded field - if gs.info is ast.Struct { - mut found_fields := []ast.StructField{} - mut embed_of_found_fields := []ast.Type{} - for embed in gs.info.embeds { - embed_sym := c.table.get_type_symbol(embed) - if f := c.table.find_field(embed_sym, field_name) { - found_fields << f - embed_of_found_fields << embed - } - } - if found_fields.len == 1 { - field = found_fields[0] - has_field = true - node.from_embed_type = embed_of_found_fields[0] - } else if found_fields.len > 1 { - c.error('ambiguous field `$field_name`', node.pos) + has_field = true + mut embed_type := ast.Type(0) + field, embed_type = c.table.find_field_from_embeds(sym, field_name) or { + if err.msg != '' { + c.error(err.msg, node.pos) } + has_field = false + ast.StructField{}, ast.Type(0) } + node.from_embed_type = embed_type } } } diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index d437999af4..d7658d1b30 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -3625,7 +3625,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) { g.write('.data)') } // struct embedding - if sym.info is ast.Struct { + if sym.info is ast.Struct || sym.info is ast.Aggregate { if node.from_embed_type != 0 { embed_sym := g.table.get_type_symbol(node.from_embed_type) embed_name := embed_sym.embed_name() diff --git a/vlib/v/tests/struct_embed_test.v b/vlib/v/tests/struct_embed_test.v index 46fef8eb96..79476c7f01 100644 --- a/vlib/v/tests/struct_embed_test.v +++ b/vlib/v/tests/struct_embed_test.v @@ -157,3 +157,68 @@ fn test_embed_method_generic() { mut app := App{} assert embed_method_generic(app) } + +type Piece = King | Queen + +struct Position { + x byte + y byte +} + +enum TeamEnum { + black + white +} + +struct PieceCommonFields { + pos Position + team TeamEnum +} + +fn (p PieceCommonFields) get_pos() Position { + return p.pos +} + +struct King { + PieceCommonFields +} + +struct Queen { + PieceCommonFields +} + +fn (piece Piece) position() Position { + mut pos := Position{} + match piece { + King, Queen { pos = piece.pos } + } + return pos +} + +fn (piece Piece) get_position() Position { + mut pos := Position{} + match piece { + King, Queen { pos = piece.get_pos() } + } + return pos +} + +fn test_match_aggregate_field() { + piece := Piece(King{ + pos: Position{1, 8} + team: .black + }) + pos := piece.position() + assert pos.x == 1 + assert pos.y == 8 +} + +fn test_match_aggregate_method() { + piece := Piece(King{ + pos: Position{1, 8} + team: .black + }) + pos := piece.get_position() + assert pos.x == 1 + assert pos.y == 8 +}