checker: type inference over a generic type should compile (#13824)

pull/13899/head
Vincenzo Palazzo 2022-04-01 18:31:27 +02:00 committed by GitHub
parent 9d2529b611
commit d7817863c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 28 additions and 4 deletions

View File

@ -327,7 +327,9 @@ pub fn (mut rng PRNG) choose<T>(array []T, k int) ?[]T {
}
mut results := []T{len: k}
mut indices := []int{len: n, init: it}
rng.shuffle(mut indices) ?
// TODO: see why exactly it is necessary to enfoce the type here in Checker.infer_fn_generic_types
// (v errors with: `inferred generic type T is ambiguous: got int, expected string`, when <int> is missing)
rng.shuffle<int>(mut indices) ?
for i in 0 .. k {
results[i] = array[indices[i]]
}

View File

@ -535,7 +535,7 @@ pub fn (mut c Checker) infer_fn_generic_types(func ast.Fn, mut node ast.CallExpr
for i, param in func.params {
mut to_set := ast.void_type
// resolve generic struct receiver
if i == 0 && node.is_method && param.typ.has_flag(.generic) {
if node.is_method && param.typ.has_flag(.generic) {
sym := c.table.sym(node.receiver_type)
match sym.info {
ast.Struct, ast.Interface, ast.SumType {

View File

@ -823,10 +823,12 @@ pub fn (mut c Checker) infix_expr(mut node ast.InfixExpr) ast.Type {
left_gen_type := c.unwrap_generic(left_type)
gen_sym := c.table.sym(left_gen_type)
need_overload := gen_sym.kind in [.struct_, .interface_]
if need_overload && !gen_sym.has_method('<') && node.op in [.ge, .le] {
if need_overload && !gen_sym.has_method_with_generic_parent('<')
&& node.op in [.ge, .le] {
c.error('cannot use `$node.op` as `<` operator method is not defined',
left_right_pos)
} else if need_overload && !gen_sym.has_method('<') && node.op == .gt {
} else if need_overload && !gen_sym.has_method_with_generic_parent('<')
&& node.op == .gt {
c.error('cannot use `>` as `<=` operator method is not defined', left_right_pos)
}
} else if left_type in ast.integer_type_idxs && right_type in ast.integer_type_idxs {

View File

@ -0,0 +1,20 @@
import datatypes
struct KeyVal<T> {
key string
val T
}
fn (a KeyVal<T>) == (b KeyVal<T>) bool {
return a.key == b.key
}
fn (a KeyVal<T>) < (b KeyVal<T>) bool {
return a.key < b.key
}
fn main() {
mut bst := datatypes.BSTree<KeyVal<int>>{}
bst.insert(KeyVal<int>{key: "alibaba", val: 12})
println(bst.in_order_traversal())
}