all: implement accessing embedded field and method of aggregate (#10907)

pull/10939/head
Daniel Däschle 2021-07-23 00:14:39 +02:00 committed by GitHub
parent f40090e8ff
commit 2c0c211c79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 164 additions and 76 deletions

View File

@ -321,7 +321,7 @@ pub fn (t &Table) type_has_method(s &TypeSymbol, name string) bool {
return false
}
// search from current type up through each parent looking for method
// type_find_method searches from current type up through each parent looking for method
pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn {
// println('type_find_method($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx')
mut ts := unsafe { s }
@ -340,6 +340,36 @@ pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn {
return none
}
pub fn (t &Table) type_find_method_from_embeds(sym &TypeSymbol, method_name string) ?(Fn, Type) {
if sym.info is Struct {
mut found_methods := []Fn{}
mut embed_of_found_methods := []Type{}
for embed in sym.info.embeds {
embed_sym := t.get_type_symbol(embed)
if m := t.type_find_method(embed_sym, method_name) {
found_methods << m
embed_of_found_methods << embed
}
}
if found_methods.len == 1 {
return found_methods[0], embed_of_found_methods[0]
} else if found_methods.len > 1 {
return error('ambiguous method `$method_name`')
}
} else if sym.info is Aggregate {
for typ in sym.info.types {
agg_sym := t.get_type_symbol(typ)
method, embed_type := t.type_find_method_from_embeds(agg_sym, method_name) or {
return err
}
if embed_type != 0 {
return method, embed_type
}
}
}
return none
}
fn (t &Table) register_aggregate_field(mut sym TypeSymbol, name string) ?StructField {
if sym.kind != .aggregate {
t.panic('Unexpected type symbol: $sym.kind')
@ -347,9 +377,7 @@ fn (t &Table) register_aggregate_field(mut sym TypeSymbol, name string) ?StructF
mut agg_info := sym.info as Aggregate
// an aggregate always has at least 2 types
mut found_once := false
mut new_field := StructField{
// default_expr: ast.empty_expr()
}
mut new_field := StructField{}
for typ in agg_info.types {
ts := t.get_type_symbol(typ)
if type_field := t.find_field(ts, name) {
@ -415,29 +443,44 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField {
return none
}
// find_field_with_embeds searches for a given field, also looking through embedded fields
pub fn (t &Table) find_field_with_embeds(sym &TypeSymbol, field_name string) ?StructField {
if f := t.find_field(sym, field_name) {
return f
} else {
// look for embedded field
if sym.info is Struct {
mut found_fields := []StructField{}
mut embed_of_found_fields := []Type{}
for embed in sym.info.embeds {
embed_sym := t.get_type_symbol(embed)
if f := t.find_field(embed_sym, field_name) {
found_fields << f
embed_of_found_fields << embed
}
}
if found_fields.len == 1 {
return found_fields[0]
} else if found_fields.len > 1 {
return error('ambiguous field `$field_name`')
// find_field_from_embeds finds and returns a field in the embeddings of a struct and the embedding type
pub fn (t &Table) find_field_from_embeds(sym &TypeSymbol, field_name string) ?(StructField, Type) {
if sym.info is Struct {
mut found_fields := []StructField{}
mut embed_of_found_fields := []Type{}
for embed in sym.info.embeds {
embed_sym := t.get_type_symbol(embed)
if field := t.find_field(embed_sym, field_name) {
found_fields << field
embed_of_found_fields << embed
}
}
return err
if found_fields.len == 1 {
return found_fields[0], embed_of_found_fields[0]
} else if found_fields.len > 1 {
return error('ambiguous field `$field_name`')
}
} else if sym.info is Aggregate {
for typ in sym.info.types {
agg_sym := t.get_type_symbol(typ)
field, embed_type := t.find_field_from_embeds(agg_sym, field_name) or { return err }
if embed_type != 0 {
return field, embed_type
}
}
}
return none
}
// find_field_with_embeds searches for a given field, also looking through embedded fields
pub fn (t &Table) find_field_with_embeds(sym &TypeSymbol, field_name string) ?StructField {
if field := t.find_field(sym, field_name) {
return field
} else {
// look for embedded field
first_err := err
field, _ := t.find_field_from_embeds(sym, field_name) or { return first_err }
return field
}
}

View File

@ -2118,7 +2118,6 @@ pub fn (mut c Checker) method_call(mut call_expr ast.CallExpr) ast.Type {
method = m
has_method = true
} else {
// can this logic be moved to ast.type_find_method() so it can be used from anywhere
if left_type_sym.info is ast.Struct {
if left_type_sym.info.parent_type != 0 {
type_sym := c.table.get_type_symbol(left_type_sym.info.parent_type)
@ -2127,24 +2126,21 @@ pub fn (mut c Checker) method_call(mut call_expr ast.CallExpr) ast.Type {
has_method = true
is_generic = true
}
} else {
mut found_methods := []ast.Fn{}
mut embed_of_found_methods := []ast.Type{}
for embed in left_type_sym.info.embeds {
embed_sym := c.table.get_type_symbol(embed)
if m := c.table.type_find_method(embed_sym, method_name) {
found_methods << m
embed_of_found_methods << embed
}
}
if found_methods.len == 1 {
method = found_methods[0]
has_method = true
is_method_from_embed = true
call_expr.from_embed_type = embed_of_found_methods[0]
} else if found_methods.len > 1 {
c.error('ambiguous method `$method_name`', call_expr.pos)
}
}
if !has_method {
has_method = true
mut embed_type := ast.Type(0)
method, embed_type = c.table.type_find_method_from_embeds(left_type_sym, method_name) or {
if err.msg != '' {
c.error(err.msg, call_expr.pos)
}
has_method = false
ast.Fn{}, ast.Type(0)
}
if embed_type != 0 {
is_method_from_embed = true
call_expr.from_embed_type = embed_type
}
}
if left_type_sym.kind == .aggregate {
@ -3379,24 +3375,16 @@ pub fn (mut c Checker) selector_expr(mut node ast.SelectorExpr) ast.Type {
field = f
} else {
// look for embedded field
if sym.info is ast.Struct {
mut found_fields := []ast.StructField{}
mut embed_of_found_fields := []ast.Type{}
for embed in sym.info.embeds {
embed_sym := c.table.get_type_symbol(embed)
if f := c.table.find_field(embed_sym, field_name) {
found_fields << f
embed_of_found_fields << embed
}
}
if found_fields.len == 1 {
field = found_fields[0]
has_field = true
node.from_embed_type = embed_of_found_fields[0]
} else if found_fields.len > 1 {
c.error('ambiguous field `$field_name`', node.pos)
has_field = true
mut embed_type := ast.Type(0)
field, embed_type = c.table.find_field_from_embeds(sym, field_name) or {
if err.msg != '' {
c.error(err.msg, node.pos)
}
has_field = false
ast.StructField{}, ast.Type(0)
}
node.from_embed_type = embed_type
if sym.kind in [.aggregate, .sum_type] {
unknown_field_msg = err.msg
}
@ -3416,24 +3404,16 @@ pub fn (mut c Checker) selector_expr(mut node ast.SelectorExpr) ast.Type {
field = f
} else {
// look for embedded field
if gs.info is ast.Struct {
mut found_fields := []ast.StructField{}
mut embed_of_found_fields := []ast.Type{}
for embed in gs.info.embeds {
embed_sym := c.table.get_type_symbol(embed)
if f := c.table.find_field(embed_sym, field_name) {
found_fields << f
embed_of_found_fields << embed
}
}
if found_fields.len == 1 {
field = found_fields[0]
has_field = true
node.from_embed_type = embed_of_found_fields[0]
} else if found_fields.len > 1 {
c.error('ambiguous field `$field_name`', node.pos)
has_field = true
mut embed_type := ast.Type(0)
field, embed_type = c.table.find_field_from_embeds(sym, field_name) or {
if err.msg != '' {
c.error(err.msg, node.pos)
}
has_field = false
ast.StructField{}, ast.Type(0)
}
node.from_embed_type = embed_type
}
}
}

View File

@ -3625,7 +3625,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) {
g.write('.data)')
}
// struct embedding
if sym.info is ast.Struct {
if sym.info is ast.Struct || sym.info is ast.Aggregate {
if node.from_embed_type != 0 {
embed_sym := g.table.get_type_symbol(node.from_embed_type)
embed_name := embed_sym.embed_name()

View File

@ -157,3 +157,68 @@ fn test_embed_method_generic() {
mut app := App{}
assert embed_method_generic(app)
}
type Piece = King | Queen
struct Position {
x byte
y byte
}
enum TeamEnum {
black
white
}
struct PieceCommonFields {
pos Position
team TeamEnum
}
fn (p PieceCommonFields) get_pos() Position {
return p.pos
}
struct King {
PieceCommonFields
}
struct Queen {
PieceCommonFields
}
fn (piece Piece) position() Position {
mut pos := Position{}
match piece {
King, Queen { pos = piece.pos }
}
return pos
}
fn (piece Piece) get_position() Position {
mut pos := Position{}
match piece {
King, Queen { pos = piece.get_pos() }
}
return pos
}
fn test_match_aggregate_field() {
piece := Piece(King{
pos: Position{1, 8}
team: .black
})
pos := piece.position()
assert pos.x == 1
assert pos.y == 8
}
fn test_match_aggregate_method() {
piece := Piece(King{
pos: Position{1, 8}
team: .black
})
pos := piece.get_position()
assert pos.x == 1
assert pos.y == 8
}