checker: fix generics fn return generics fn type (fix #10085) (#10088)

pull/10099/head
yuyi 2021-05-13 17:26:13 +08:00 committed by GitHub
parent 143c3d4bb4
commit 14b7ce0f04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 1 deletions

View File

@ -426,7 +426,7 @@ pub mut:
} }
pub struct FnType { pub struct FnType {
pub: pub mut:
is_anon bool is_anon bool
has_decl bool has_decl bool
func Fn func Fn

View File

@ -508,6 +508,15 @@ pub fn (mut c Checker) infer_fn_generic_types(f ast.Fn, mut call_expr ast.CallEx
if param.typ.has_flag(.generic) && param_type_sym.name == gt_name { if param.typ.has_flag(.generic) && param_type_sym.name == gt_name {
to_set = c.table.mktyp(arg.typ) to_set = c.table.mktyp(arg.typ)
mut sym := c.table.get_type_symbol(arg.typ)
if mut sym.info is ast.FnType {
if !sym.info.is_anon {
sym.info.func.name = ''
idx := c.table.find_or_register_fn_type(c.mod, sym.info.func,
true, false)
to_set = ast.new_type(idx).derive(arg.typ)
}
}
if arg.expr.is_auto_deref_var() { if arg.expr.is_auto_deref_var() {
to_set = to_set.deref() to_set = to_set.deref()
} }

View File

@ -0,0 +1,60 @@
fn neg(a int) int {
return -a
}
fn normal_v1(func fn (int) int) fn (int) int {
assert typeof(func).name == typeof(neg).name
return func
}
fn normal_v2(func fn (int) int) fn (int) int {
f := func
assert typeof(f).name == typeof(neg).name
return f
}
fn generic_v1<T>(func T) T {
assert T.name == typeof(neg).name
assert typeof(func).name == typeof(neg).name
return func
}
fn generic_v2<T>(func T) T {
assert T.name == typeof(neg).name
f := func
assert typeof(f).name == typeof(neg).name
return f
}
fn mixed_v1<T>(func T) fn (int) int {
assert T.name == typeof(neg).name
assert typeof(func).name == typeof(neg).name
return func
}
fn mixed_v2<T>(func T) fn (int) int {
assert T.name == typeof(neg).name
f := func
assert typeof(f).name == typeof(neg).name
return f
}
fn test_generics_with_generics_fn_return_type() {
mut f := neg
assert f(1) == -1
f = normal_v1(neg)
assert f(2) == -2
f = normal_v2(neg)
assert f(3) == -3
f = generic_v1(neg)
assert f(4) == -4
f = generic_v2(neg)
assert f(5) == -5
f = mixed_v1(neg)
assert f(6) == -6
f = mixed_v2(neg)
assert f(7) == -7
}