From 8862c3af0fb59b5e4644dff288591ce6f264f759 Mon Sep 17 00:00:00 2001 From: yuyi Date: Tue, 14 Sep 2021 20:56:12 +0800 Subject: [PATCH] all: implement `if foo in [Foo1, Foo2, Foo3]` (#11486) --- vlib/v/checker/checker.v | 17 ++++++++++------- vlib/v/gen/c/infix_expr.v | 29 +++++++++++++++++++++++++++++ vlib/v/parser/expr.v | 8 ++++++++ vlib/v/parser/parser.v | 13 +++++++++++++ vlib/v/tests/in_expression_test.v | 17 +++++++++++++++++ 5 files changed, 77 insertions(+), 7 deletions(-) diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 4ad49c3b87..ec32488af5 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -1356,11 +1356,12 @@ pub fn (mut c Checker) infix_expr(mut node ast.InfixExpr) ast.Type { .key_in, .not_in { match right_final.kind { .array { - elem_type := right_final.array_info().elem_type - // if left_default.kind != right_sym.kind { - c.check_expected(left_type, elem_type) or { - c.error('left operand to `$node.op` does not match the array element type: $err.msg', - left_right_pos) + if left_sym.kind !in [.sum_type, .interface_] { + elem_type := right_final.array_info().elem_type + c.check_expected(left_type, elem_type) or { + c.error('left operand to `$node.op` does not match the array element type: $err.msg', + left_right_pos) + } } } .map { @@ -4522,8 +4523,10 @@ pub fn (mut c Checker) array_init(mut node ast.ArrayInit) ast.Type { c.expected_type = elem_type continue } - c.check_expected(typ, elem_type) or { - c.error('invalid array element: $err.msg', expr.position()) + if expr !is ast.TypeNode { + c.check_expected(typ, elem_type) or { + c.error('invalid array element: $err.msg', expr.position()) + } } } if node.is_fixed { diff --git a/vlib/v/gen/c/infix_expr.v b/vlib/v/gen/c/infix_expr.v index 71c4d5a7ea..de9bff7c27 100644 --- a/vlib/v/gen/c/infix_expr.v +++ b/vlib/v/gen/c/infix_expr.v @@ -313,6 +313,15 @@ fn (mut g Gen) infix_expr_cmp_op(node ast.InfixExpr) { } } +fn (mut g Gen) infix_expr_in_sumtype_interface_array(infix_exprs []ast.InfixExpr) { + for i in 0 .. infix_exprs.len { + g.infix_expr_is_op(infix_exprs[i]) + if i != infix_exprs.len - 1 { + g.write(' || ') + } + } +} + // infix_expr_in_op generates code for `in` and `!in` fn (mut g Gen) infix_expr_in_op(node ast.InfixExpr) { left := g.unwrap(node.left_type) @@ -321,6 +330,26 @@ fn (mut g Gen) infix_expr_in_op(node ast.InfixExpr) { g.write('!') } if right.unaliased_sym.kind == .array { + if left.sym.kind in [.sum_type, .interface_] { + if mut node.right is ast.ArrayInit { + if node.right.exprs.len > 0 { + mut infix_exprs := []ast.InfixExpr{} + for i in 0 .. node.right.exprs.len { + infix_exprs << ast.InfixExpr{ + op: .key_is + left: node.left + left_type: node.left_type + right: node.right.exprs[i] + right_type: node.right.expr_types[i] + } + } + g.write('(') + g.infix_expr_in_sumtype_interface_array(infix_exprs) + g.write(')') + return + } + } + } if mut node.right is ast.ArrayInit { if node.right.exprs.len > 0 { // `a in [1,2,3]` optimization => `a == 1 || a == 2 || a == 3` diff --git a/vlib/v/parser/expr.v b/vlib/v/parser/expr.v index a126c959d1..c6c329e4d3 100644 --- a/vlib/v/parser/expr.v +++ b/vlib/v/parser/expr.v @@ -472,7 +472,15 @@ fn (mut p Parser) infix_expr(left ast.Expr) ast.Expr { if op in [.key_is, .not_is] { p.expecting_type = true } + is_key_in := op in [.key_in, .not_in] + if is_key_in { + p.inside_in_array = true + } right = p.expr(precedence) + if is_key_in { + p.inside_in_array = false + } + p.expecting_type = prev_expecting_type if p.pref.is_vet && op in [.key_in, .not_in] && right is ast.ArrayInit && (right as ast.ArrayInit).exprs.len == 1 { diff --git a/vlib/v/parser/parser.v b/vlib/v/parser/parser.v index 827fd4fbb1..34c471f853 100644 --- a/vlib/v/parser/parser.v +++ b/vlib/v/parser/parser.v @@ -44,6 +44,7 @@ mut: inside_unsafe_fn bool inside_str_interp bool inside_array_lit bool + inside_in_array bool or_is_handled bool // ignore `or` in this expression builtin_mod bool // are we in the `builtin` module? mod string // current module name @@ -2233,6 +2234,18 @@ pub fn (mut p Parser) name_expr() ast.Expr { // JS. function call with more than 1 dot node = p.call_expr(language, mod) } else { + if p.inside_in_array && ((lit0_is_capital && !known_var && language == .v) + || (p.peek_tok.kind == .dot && p.peek_token(2).lit.len > 0 + && p.peek_token(2).lit[0].is_capital()) + || p.table.find_type_idx(p.mod + '.' + p.tok.lit) > 0) { + type_pos := p.tok.position() + typ := p.parse_type() + return ast.TypeNode{ + typ: typ + pos: type_pos + } + } + ident := p.parse_ident(language) node = ident if p.inside_defer { diff --git a/vlib/v/tests/in_expression_test.v b/vlib/v/tests/in_expression_test.v index 8d331f0bb1..63a3730a63 100644 --- a/vlib/v/tests/in_expression_test.v +++ b/vlib/v/tests/in_expression_test.v @@ -264,3 +264,20 @@ fn test_in_expression_numeric() { assert 1.0625 in f2 assert 3.5 !in f2 } + +struct Foo1 {} + +struct Foo2 {} + +struct Foo3 {} + +type Foo = Foo1 | Foo2 | Foo3 + +fn test_in_sumtype_array() { + foo := Foo(Foo3{}) + + if foo in [Foo1, Foo3] { + println(foo) + assert true + } +}