all: implement accessing embedded field and method of aggregate (#10907)
parent
f40090e8ff
commit
2c0c211c79
|
@ -321,7 +321,7 @@ pub fn (t &Table) type_has_method(s &TypeSymbol, name string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// search from current type up through each parent looking for method
|
||||
// type_find_method searches from current type up through each parent looking for method
|
||||
pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn {
|
||||
// println('type_find_method($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx')
|
||||
mut ts := unsafe { s }
|
||||
|
@ -340,6 +340,36 @@ pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn {
|
|||
return none
|
||||
}
|
||||
|
||||
pub fn (t &Table) type_find_method_from_embeds(sym &TypeSymbol, method_name string) ?(Fn, Type) {
|
||||
if sym.info is Struct {
|
||||
mut found_methods := []Fn{}
|
||||
mut embed_of_found_methods := []Type{}
|
||||
for embed in sym.info.embeds {
|
||||
embed_sym := t.get_type_symbol(embed)
|
||||
if m := t.type_find_method(embed_sym, method_name) {
|
||||
found_methods << m
|
||||
embed_of_found_methods << embed
|
||||
}
|
||||
}
|
||||
if found_methods.len == 1 {
|
||||
return found_methods[0], embed_of_found_methods[0]
|
||||
} else if found_methods.len > 1 {
|
||||
return error('ambiguous method `$method_name`')
|
||||
}
|
||||
} else if sym.info is Aggregate {
|
||||
for typ in sym.info.types {
|
||||
agg_sym := t.get_type_symbol(typ)
|
||||
method, embed_type := t.type_find_method_from_embeds(agg_sym, method_name) or {
|
||||
return err
|
||||
}
|
||||
if embed_type != 0 {
|
||||
return method, embed_type
|
||||
}
|
||||
}
|
||||
}
|
||||
return none
|
||||
}
|
||||
|
||||
fn (t &Table) register_aggregate_field(mut sym TypeSymbol, name string) ?StructField {
|
||||
if sym.kind != .aggregate {
|
||||
t.panic('Unexpected type symbol: $sym.kind')
|
||||
|
@ -347,9 +377,7 @@ fn (t &Table) register_aggregate_field(mut sym TypeSymbol, name string) ?StructF
|
|||
mut agg_info := sym.info as Aggregate
|
||||
// an aggregate always has at least 2 types
|
||||
mut found_once := false
|
||||
mut new_field := StructField{
|
||||
// default_expr: ast.empty_expr()
|
||||
}
|
||||
mut new_field := StructField{}
|
||||
for typ in agg_info.types {
|
||||
ts := t.get_type_symbol(typ)
|
||||
if type_field := t.find_field(ts, name) {
|
||||
|
@ -415,29 +443,44 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField {
|
|||
return none
|
||||
}
|
||||
|
||||
// find_field_with_embeds searches for a given field, also looking through embedded fields
|
||||
pub fn (t &Table) find_field_with_embeds(sym &TypeSymbol, field_name string) ?StructField {
|
||||
if f := t.find_field(sym, field_name) {
|
||||
return f
|
||||
} else {
|
||||
// look for embedded field
|
||||
// find_field_from_embeds finds and returns a field in the embeddings of a struct and the embedding type
|
||||
pub fn (t &Table) find_field_from_embeds(sym &TypeSymbol, field_name string) ?(StructField, Type) {
|
||||
if sym.info is Struct {
|
||||
mut found_fields := []StructField{}
|
||||
mut embed_of_found_fields := []Type{}
|
||||
for embed in sym.info.embeds {
|
||||
embed_sym := t.get_type_symbol(embed)
|
||||
if f := t.find_field(embed_sym, field_name) {
|
||||
found_fields << f
|
||||
if field := t.find_field(embed_sym, field_name) {
|
||||
found_fields << field
|
||||
embed_of_found_fields << embed
|
||||
}
|
||||
}
|
||||
if found_fields.len == 1 {
|
||||
return found_fields[0]
|
||||
return found_fields[0], embed_of_found_fields[0]
|
||||
} else if found_fields.len > 1 {
|
||||
return error('ambiguous field `$field_name`')
|
||||
}
|
||||
} else if sym.info is Aggregate {
|
||||
for typ in sym.info.types {
|
||||
agg_sym := t.get_type_symbol(typ)
|
||||
field, embed_type := t.find_field_from_embeds(agg_sym, field_name) or { return err }
|
||||
if embed_type != 0 {
|
||||
return field, embed_type
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return none
|
||||
}
|
||||
|
||||
// find_field_with_embeds searches for a given field, also looking through embedded fields
|
||||
pub fn (t &Table) find_field_with_embeds(sym &TypeSymbol, field_name string) ?StructField {
|
||||
if field := t.find_field(sym, field_name) {
|
||||
return field
|
||||
} else {
|
||||
// look for embedded field
|
||||
first_err := err
|
||||
field, _ := t.find_field_from_embeds(sym, field_name) or { return first_err }
|
||||
return field
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2118,7 +2118,6 @@ pub fn (mut c Checker) method_call(mut call_expr ast.CallExpr) ast.Type {
|
|||
method = m
|
||||
has_method = true
|
||||
} else {
|
||||
// can this logic be moved to ast.type_find_method() so it can be used from anywhere
|
||||
if left_type_sym.info is ast.Struct {
|
||||
if left_type_sym.info.parent_type != 0 {
|
||||
type_sym := c.table.get_type_symbol(left_type_sym.info.parent_type)
|
||||
|
@ -2127,24 +2126,21 @@ pub fn (mut c Checker) method_call(mut call_expr ast.CallExpr) ast.Type {
|
|||
has_method = true
|
||||
is_generic = true
|
||||
}
|
||||
} else {
|
||||
mut found_methods := []ast.Fn{}
|
||||
mut embed_of_found_methods := []ast.Type{}
|
||||
for embed in left_type_sym.info.embeds {
|
||||
embed_sym := c.table.get_type_symbol(embed)
|
||||
if m := c.table.type_find_method(embed_sym, method_name) {
|
||||
found_methods << m
|
||||
embed_of_found_methods << embed
|
||||
}
|
||||
}
|
||||
if found_methods.len == 1 {
|
||||
method = found_methods[0]
|
||||
if !has_method {
|
||||
has_method = true
|
||||
is_method_from_embed = true
|
||||
call_expr.from_embed_type = embed_of_found_methods[0]
|
||||
} else if found_methods.len > 1 {
|
||||
c.error('ambiguous method `$method_name`', call_expr.pos)
|
||||
mut embed_type := ast.Type(0)
|
||||
method, embed_type = c.table.type_find_method_from_embeds(left_type_sym, method_name) or {
|
||||
if err.msg != '' {
|
||||
c.error(err.msg, call_expr.pos)
|
||||
}
|
||||
has_method = false
|
||||
ast.Fn{}, ast.Type(0)
|
||||
}
|
||||
if embed_type != 0 {
|
||||
is_method_from_embed = true
|
||||
call_expr.from_embed_type = embed_type
|
||||
}
|
||||
}
|
||||
if left_type_sym.kind == .aggregate {
|
||||
|
@ -3379,24 +3375,16 @@ pub fn (mut c Checker) selector_expr(mut node ast.SelectorExpr) ast.Type {
|
|||
field = f
|
||||
} else {
|
||||
// look for embedded field
|
||||
if sym.info is ast.Struct {
|
||||
mut found_fields := []ast.StructField{}
|
||||
mut embed_of_found_fields := []ast.Type{}
|
||||
for embed in sym.info.embeds {
|
||||
embed_sym := c.table.get_type_symbol(embed)
|
||||
if f := c.table.find_field(embed_sym, field_name) {
|
||||
found_fields << f
|
||||
embed_of_found_fields << embed
|
||||
}
|
||||
}
|
||||
if found_fields.len == 1 {
|
||||
field = found_fields[0]
|
||||
has_field = true
|
||||
node.from_embed_type = embed_of_found_fields[0]
|
||||
} else if found_fields.len > 1 {
|
||||
c.error('ambiguous field `$field_name`', node.pos)
|
||||
mut embed_type := ast.Type(0)
|
||||
field, embed_type = c.table.find_field_from_embeds(sym, field_name) or {
|
||||
if err.msg != '' {
|
||||
c.error(err.msg, node.pos)
|
||||
}
|
||||
has_field = false
|
||||
ast.StructField{}, ast.Type(0)
|
||||
}
|
||||
node.from_embed_type = embed_type
|
||||
if sym.kind in [.aggregate, .sum_type] {
|
||||
unknown_field_msg = err.msg
|
||||
}
|
||||
|
@ -3416,24 +3404,16 @@ pub fn (mut c Checker) selector_expr(mut node ast.SelectorExpr) ast.Type {
|
|||
field = f
|
||||
} else {
|
||||
// look for embedded field
|
||||
if gs.info is ast.Struct {
|
||||
mut found_fields := []ast.StructField{}
|
||||
mut embed_of_found_fields := []ast.Type{}
|
||||
for embed in gs.info.embeds {
|
||||
embed_sym := c.table.get_type_symbol(embed)
|
||||
if f := c.table.find_field(embed_sym, field_name) {
|
||||
found_fields << f
|
||||
embed_of_found_fields << embed
|
||||
}
|
||||
}
|
||||
if found_fields.len == 1 {
|
||||
field = found_fields[0]
|
||||
has_field = true
|
||||
node.from_embed_type = embed_of_found_fields[0]
|
||||
} else if found_fields.len > 1 {
|
||||
c.error('ambiguous field `$field_name`', node.pos)
|
||||
mut embed_type := ast.Type(0)
|
||||
field, embed_type = c.table.find_field_from_embeds(sym, field_name) or {
|
||||
if err.msg != '' {
|
||||
c.error(err.msg, node.pos)
|
||||
}
|
||||
has_field = false
|
||||
ast.StructField{}, ast.Type(0)
|
||||
}
|
||||
node.from_embed_type = embed_type
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3625,7 +3625,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) {
|
|||
g.write('.data)')
|
||||
}
|
||||
// struct embedding
|
||||
if sym.info is ast.Struct {
|
||||
if sym.info is ast.Struct || sym.info is ast.Aggregate {
|
||||
if node.from_embed_type != 0 {
|
||||
embed_sym := g.table.get_type_symbol(node.from_embed_type)
|
||||
embed_name := embed_sym.embed_name()
|
||||
|
|
|
@ -157,3 +157,68 @@ fn test_embed_method_generic() {
|
|||
mut app := App{}
|
||||
assert embed_method_generic(app)
|
||||
}
|
||||
|
||||
type Piece = King | Queen
|
||||
|
||||
struct Position {
|
||||
x byte
|
||||
y byte
|
||||
}
|
||||
|
||||
enum TeamEnum {
|
||||
black
|
||||
white
|
||||
}
|
||||
|
||||
struct PieceCommonFields {
|
||||
pos Position
|
||||
team TeamEnum
|
||||
}
|
||||
|
||||
fn (p PieceCommonFields) get_pos() Position {
|
||||
return p.pos
|
||||
}
|
||||
|
||||
struct King {
|
||||
PieceCommonFields
|
||||
}
|
||||
|
||||
struct Queen {
|
||||
PieceCommonFields
|
||||
}
|
||||
|
||||
fn (piece Piece) position() Position {
|
||||
mut pos := Position{}
|
||||
match piece {
|
||||
King, Queen { pos = piece.pos }
|
||||
}
|
||||
return pos
|
||||
}
|
||||
|
||||
fn (piece Piece) get_position() Position {
|
||||
mut pos := Position{}
|
||||
match piece {
|
||||
King, Queen { pos = piece.get_pos() }
|
||||
}
|
||||
return pos
|
||||
}
|
||||
|
||||
fn test_match_aggregate_field() {
|
||||
piece := Piece(King{
|
||||
pos: Position{1, 8}
|
||||
team: .black
|
||||
})
|
||||
pos := piece.position()
|
||||
assert pos.x == 1
|
||||
assert pos.y == 8
|
||||
}
|
||||
|
||||
fn test_match_aggregate_method() {
|
||||
piece := Piece(King{
|
||||
pos: Position{1, 8}
|
||||
team: .black
|
||||
})
|
||||
pos := piece.get_position()
|
||||
assert pos.x == 1
|
||||
assert pos.y == 8
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue