checker: infer generic type T from matching fn call argument (#6298)

pull/6567/head
Nick Treleaven 2020-10-06 14:34:02 +01:00 committed by GitHub
parent 580fefe63b
commit f7decfe399
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 4 deletions

View File

@ -372,3 +372,20 @@ pub fn (mut c Checker) string_inter_lit(mut node ast.StringInterLiteral) table.T
pub fn (c &Checker) check_sumtype_compatibility(a, b table.Type) bool {
return c.table.sumtype_has_variant(a, b) || c.table.sumtype_has_variant(b, a)
}
pub fn (mut c Checker) infer_fn_types(f table.Fn, mut call_expr ast.CallExpr) {
gt_name := 'T'
mut typ := table.void_type
for i, arg in f.params {
if arg.type_source_name == gt_name {
typ = call_expr.args[i].typ
break
}
}
if typ == table.void_type {
c.error('could not infer generic type `$gt_name` in call to `$f.name`', call_expr.pos)
} else {
c.table.register_fn_gen_type(f.name, typ)
call_expr.generic_type = typ
}
}

View File

@ -56,6 +56,7 @@ mut:
inside_sql bool // to handle sql table fields pseudo variables
cur_orm_ts table.TypeSymbol
error_details []string
generic_funcs []&ast.FnDecl
}
pub fn new_checker(table &table.Table, pref &pref.Preferences) Checker {
@ -80,6 +81,7 @@ pub fn (mut c Checker) check(ast_file ast.File) {
c.stmt(stmt)
}
c.check_scope_vars(c.file.scope)
c.post_process_generic_fns()
}
pub fn (mut c Checker) check_scope_vars(sc &ast.Scope) {
@ -1505,6 +1507,10 @@ pub fn (mut c Checker) call_fn(mut call_expr ast.CallExpr) table.Type {
}
}
}
if f.is_generic && call_expr.generic_type == table.void_type {
// no type arguments given in call, attempt implicit instantiation
c.infer_fn_types(f, mut call_expr)
}
if call_expr.generic_type != table.void_type && f.return_type != 0 { // table.t_type {
// Handle `foo<T>() T` => `foo<int>() int` => return int
return_sym := c.table.get_type_symbol(f.return_type)
@ -3962,10 +3968,16 @@ fn (mut c Checker) fetch_and_verify_orm_fields(info table.Struct, pos token.Posi
return fields
}
fn (mut c Checker) fn_decl(mut node ast.FnDecl) {
c.returns = false
if node.is_generic && c.cur_generic_type == 0 { // need the cur_generic_type check to avoid inf. recursion
// loop thru each generic type and generate a function
fn (mut c Checker) post_process_generic_fns() {
// Loop thru each generic function concrete type.
// Check each specific fn instantiation.
for i in 0 .. c.generic_funcs.len {
mut node := c.generic_funcs[i]
if c.table.fn_gen_types.len == 0 {
// no concrete types, so just skip:
continue
}
// eprintln('>> post_process_generic_fns $c.file.path | $node.name , c.table.fn_gen_types.len: $c.table.fn_gen_types.len')
for gen_type in c.table.fn_gen_types[node.name] {
c.cur_generic_type = gen_type
// sym:=c.table.get_type_symbol(gen_type)
@ -3973,6 +3985,24 @@ fn (mut c Checker) fn_decl(mut node ast.FnDecl) {
c.fn_decl(mut node)
}
c.cur_generic_type = 0
c.generic_funcs[i] = 0
}
// The generic funtions for each file/mod should be
// postprocessed just once in the checker, while the file/mod
// context is still the same.
c.generic_funcs = []
}
fn (mut c Checker) fn_decl(mut node ast.FnDecl) {
c.returns = false
if node.is_generic && c.cur_generic_type == 0 {
// Just remember the generic function for now.
// It will be processed later in c.post_process_generic_fns,
// after all other normal functions are processed.
// This is done so that all generic function calls can
// have a chance to populate c.table.fn_gen_types with
// the correct concrete types.
c.generic_funcs << node
return
}
if node.language == .v && !c.is_builtin_mod {

View File

@ -0,0 +1,33 @@
fn call<T>(v T) {
}
fn simple<T>(p T) T {
return p
}
fn test_infer() {
call(3)
i := 4
r := simple(i)
assert r == 4
}
fn test_explicit_calls_should_also_work() {
call<int>(2)
assert true
simple<int>(5)
assert true
}
//
fn choose4<T>(a, b, c, d T) T {
// NB: a similar construct is used in prime31's via engine
return a
}
fn test_calling_generic_fn_with_many_params() {
x := choose4(1, 2, 3, 4)
assert x == 1
y := choose4<string>('abc', 'xyz', 'def', 'ghi')
assert y == 'abc'
}