checker: fix generic fn infering fn type argument (fix #14243) (#14256)

yuyi 2022-05-02 21:17:27 +08:00 committed by Jef Roosens
parent 21555fb03a
commit 374b6927bc
Signed by: Jef Roosens
GPG Key ID: B75D4F293C7052DB
2 changed files with 52 additions and 2 deletions

View File

@ -627,7 +627,9 @@ pub fn (mut c Checker) infer_fn_generic_types(func ast.Fn, mut node ast.CallExpr
param_elem_info = param_elem_sym.info as ast.Array
param_elem_sym = c.table.sym(param_elem_info.elem_type)
} else {
to_set = arg_elem_info.elem_type
if param_elem_sym.name == gt_name {
typ = arg_elem_info.elem_type
}
break
}
}
@ -644,7 +646,9 @@ pub fn (mut c Checker) infer_fn_generic_types(func ast.Fn, mut node ast.CallExpr
param_elem_info = param_elem_sym.info as ast.ArrayFixed
param_elem_sym = c.table.sym(param_elem_info.elem_type)
} else {
to_set = arg_elem_info.elem_type
if param_elem_sym.name == gt_name {
typ = arg_elem_info.elem_type
}
break
}
}
@ -659,6 +663,21 @@ pub fn (mut c Checker) infer_fn_generic_types(func ast.Fn, mut node ast.CallExpr
&& c.table.sym(param_map_info.value_type).name == gt_name {
typ = arg_map_info.value_type
}
} else if arg_sym.kind == .function && param_type_sym.kind == .function {
arg_type_func := (arg_sym.info as ast.FnType).func
param_type_func := (param_type_sym.info as ast.FnType).func
if param_type_func.params.len == arg_type_func.params.len {
for n, fn_param in param_type_func.params {
if fn_param.typ.has_flag(.generic)
&& c.table.sym(fn_param.typ).name == gt_name {
typ = arg_type_func.params[n].typ
}
}
if param_type_func.return_type.has_flag(.generic)
&& c.table.sym(param_type_func.return_type).name == gt_name {
typ = arg_type_func.return_type
}
}
} else if arg_sym.kind in [.struct_, .interface_, .sum_type] {
mut generic_types := []ast.Type{}
mut concrete_types := []ast.Type{}

View File

@ -0,0 +1,31 @@
fn test_generic_fn_infer_fn_type_argument() {
to_r := fn (x int) rune {
return [`😺`, `😸`, `😹`, `😻`, `😾`][x - 1]
}
to_f64 := fn (x int) f64 {
return f64(x) + 0.123
}
to_s := fn (x int) string {
return ['One', 'Two', 'Three', 'Four', 'Five'][x - 1]
}
items := [1, 2, 3, 4, 5]
ret_r := fmap(to_r, items)
println('${ret_r.map(rune(it))}')
assert '${ret_r.map(rune(it))}' == '[`😺`, `😸`, `😹`, `😻`, `😾`]'
// returns random same number for every item in array
ret_f64 := fmap(to_f64, items)
println(ret_f64)
assert ret_f64 == [1.123, 2.123, 3.123, 4.123, 5.123]
ret_s := fmap(to_s, items)
println(ret_s)
assert ret_s == ['One', 'Two', 'Three', 'Four', 'Five']
}
// [noah04 #14214] code
fn fmap<I, O>(func fn (I) O, list []I) []O {
return []O{len: list.len, init: func(list[it])}
}