all: implement iterators in for loops (#7867)

pull/7877/head
spaceface 2021-01-05 01:06:44 +01:00 committed by GitHub
parent 82162b8ff8
commit efb80bdffd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 101 additions and 23 deletions

View File

@ -2824,32 +2824,53 @@ fn (mut c Checker) stmt(node ast.Stmt) {
}
} else {
sym := c.table.get_type_symbol(typ)
if sym.kind == .map && !(node.key_var.len > 0 && node.val_var.len > 0) {
c.error('declare a key and a value variable when ranging a map: `for key, val in map {`\n' +
'use `_` if you do not need the variable', node.pos)
}
if node.key_var.len > 0 {
key_type := match sym.kind {
.map { sym.map_info().key_type }
else { table.int_type }
}
node.key_type = key_type
node.scope.update_var_type(node.key_var, key_type)
}
mut value_type := c.table.value_type(typ)
if value_type == table.void_type || typ.has_flag(.optional) {
if typ != table.void_type {
c.error('for in: cannot index `${c.table.type_to_str(typ)}`',
if sym.kind == .struct_ {
// iterators
next_fn := sym.find_method('next') or {
c.error('a struct must have a `next()` method to be an iterator',
node.cond.position())
return
}
if !next_fn.return_type.has_flag(.optional) {
c.error('iterator method `next()` must return an optional', node.cond.position())
}
// the receiver
if next_fn.params.len != 1 {
c.error('iterator method `next()` must have 0 parameters', node.cond.position())
}
val_type := next_fn.return_type.clear_flag(.optional)
node.cond_type = typ
node.kind = sym.kind
node.val_type = val_type
node.scope.update_var_type(node.val_var, val_type)
} else {
if sym.kind == .map && !(node.key_var.len > 0 && node.val_var.len > 0) {
c.error('declare a key and a value variable when ranging a map: `for key, val in map {`\n' +
'use `_` if you do not need the variable', node.pos)
}
if node.key_var.len > 0 {
key_type := match sym.kind {
.map { sym.map_info().key_type }
else { table.int_type }
}
node.key_type = key_type
node.scope.update_var_type(node.key_var, key_type)
}
mut value_type := c.table.value_type(typ)
if value_type == table.void_type || typ.has_flag(.optional) {
if typ != table.void_type {
c.error('for in: cannot index `${c.table.type_to_str(typ)}`',
node.cond.position())
}
}
if node.val_is_mut {
value_type = value_type.to_ptr()
}
node.cond_type = typ
node.kind = sym.kind
node.val_type = value_type
node.scope.update_var_type(node.val_var, value_type)
}
if node.val_is_mut {
value_type = value_type.to_ptr()
}
node.cond_type = typ
node.kind = sym.kind
node.val_type = value_type
node.scope.update_var_type(node.val_var, value_type)
}
c.check_loop_label(node.label, node.pos)
c.stmts(node.stmts)

View File

@ -1305,6 +1305,27 @@ fn (mut g Gen) for_in(it ast.ForInStmt) {
g.expr(it.cond)
g.writeln('.str[$i];')
}
} else if it.kind == .struct_ {
cond_type_sym := g.table.get_type_symbol(it.cond_type)
next_fn := cond_type_sym.find_method('next') or {
verror('`next` method not found')
return
}
ret_typ := next_fn.return_type
g.writeln('while (1) {')
t := g.new_tmp_var()
receiver_styp := g.typ(next_fn.params[0].typ)
fn_name := receiver_styp.replace_each(['*', '', '.', '__']) + '_next'
g.write('\t${g.typ(ret_typ)} $t = ${fn_name}(')
if !it.cond_type.is_ptr() {
g.write('&')
}
g.expr(it.cond)
g.writeln(');')
g.writeln('\tif (!${t}.ok) { break; }')
val := if it.val_var in ['', '_'] { g.new_tmp_var() } else { it.val_var }
val_styp := g.typ(it.val_type)
g.writeln('\t$val_styp $val = *($val_styp*)${t}.data;')
} else {
s := g.table.type_to_str(it.cond_type)
g.error('for in: unhandled symbol `$it.cond` of type `$s`', it.pos)

View File

@ -0,0 +1,36 @@
struct Doubler {
mut:
val int
until int
}
fn (mut it Doubler) next() ?int {
v := it.val
if v > it.until {
return none
}
it.val *= 2
return v
}
fn doubler(start int, until int) Doubler {
return Doubler{start, until}
}
fn test_for_in_iterator() {
mut d := doubler(5, 30)
mut vals := []int{}
for val in d {
vals << val
}
assert vals == [5, 10, 20]
}
fn test_for_in_empty_iterator() {
mut d := doubler(5, 2)
mut vals := []int{}
for val in d {
vals << val
}
assert vals == []
}