WorldCoder, a Model-Based LLM Agent: Building World Models by Writing Code and Interacting with the Environment

Hao Tang, Darren Key, Kevin Ellis
Cornell University
How could an AI experiment with a new environment and then very quickly learn how it works and achieve a range of novel goals? Our agent programs a world model in code, explores by inventing reward functions it optimistically believes it can achieve, and then plans to achieve novel goals. On certain video game and robot planning tasks, it learns 10,000x faster than deep RL, at a fraction of the LLM cost of ReAct.

Learning a novel variant of Sokoban with teleports

WorldCoder learns how the teleporters work and uses them to push the boxes onto their targets faster.

Learned Transition Function

    class Entity:
        def __init__(self, x, y):
            self.name = self.__class__.__name__
            self.x = x
            self.y = y
        def __repr__(self):
            return f'{self.name}({self.x}, {self.y})'
        def __eq__(self, other):
            return self.name == other.name and self.x == other.x and self.y == other.y
        def __hash__(self):
            return hash((self.name, self.x, self.y))
    class Player(Entity): pass
    class Wall(Entity): pass
    class Box(Entity): pass
    class Teleporter(Entity): pass
    class Target(Entity): pass
    def get_entities_by_name(entities, name):
        return [ entity for entity in entities if entity.name == name ]
    def get_entities_by_position(entities, x, y):
        return [ entity for entity in entities if entity.x == x and entity.y == y ]
    def move_entity(entity, dx, dy, state):
        entity.x += dx
        entity.y += dy
        if any(isinstance(e, Wall) for e in get_entities_by_position(state, entity.x, entity.y)):
            # can't move into a Wall
            entity.x -= dx
            entity.y -= dy
        elif dx != 0 and any(isinstance(e, Box) for e in get_entities_by_position(state, entity.x, entity.y)):
            # push the box
            box = next(e for e in get_entities_by_position(state, entity.x, entity.y) if isinstance(e, Box))
            push_box = move_entity(box, dx, 0, state) if dx != 0 else move_entity(box, 0, dy, state)
            if not push_box:
                # can't move the box into a Wall
                entity.x -= dx
                entity.y -= dy
        return True
    def can_move_to(entity, dx, dy, state):
        new_x, new_y = entity.x + dx, entity.y + dy
        entities_in_destination = get_entities_by_position(state, new_x, new_y)
        is_destination_blocked = any(isinstance(e, Wall) for e in entities_in_destination)
        would_box_be_pushed_off_grid = (new_x < 0 or new_y < 0 or new_x > 6 or new_y > 6)
        return not (is_destination_blocked or would_box_be_pushed_off_grid)
    def move_entity_if_possible(entity, dx, dy, state):
        if can_move_to(entity, dx, dy, state):
            entity.x += dx
            entity.y += dy
            # return True if successful
            return True
        # return False if unsuccessful
        return False
    def transition(state, action):
        player = get_entities_by_name(state, "Player")[0]
        dx, dy = 0, 0
        if action == "move up":
            dx, dy = 0, -1
        elif action == "move down":
            dx, dy = 0, 1
        elif action == "move left":
            dx, dy = -1, 0
        elif action == "move right":
            dx, dy = 1, 0
        if action.startswith('move'):
            entities_in_destination = get_entities_by_position(state, player.x + dx, player.y + dy)
            if any(isinstance(e, Box) for e in entities_in_destination):  # if there is a box in the destination
                box = next(e for e in entities_in_destination if isinstance(e, Box))
                entities_in_box_destination = get_entities_by_position(state, box.x + dx, box.y + dy)
                if not any(isinstance(e, Box) for e in entities_in_box_destination) and not any(isinstance(e, Wall) for e in entities_in_box_destination):  # Box cannot be pushed into another box or wall.
                    box_pushed = move_entity_if_possible(box, dx, dy, state)
                    if box_pushed:  # if box was successfully pushed, move player
                        move_entity_if_possible(player, dx, dy, state)
            else:  # if there is no box in the destination, move player
                move_entity_if_possible(player, dx, dy, state)
        elif action == "activate":
            activate_teleporter(player, state)
        return state
    def activate_teleporter(entity, state):
        current_position_entities = get_entities_by_position(state, entity.x, entity.y)
        if any(isinstance(e, Teleporter) for e in current_position_entities):
            # can only activate if on a Teleporter
            teleporters = get_entities_by_name(state, 'Teleporter')
            other_teleporter = next(t for t in teleporters if t.x != entity.x or t.y != entity.y)
            other_teleporter_x = other_teleporter.x
            other_teleporter_y = other_teleporter.y
            other_teleported_position_entities = get_entities_by_position(state, other_teleporter_x, other_teleporter_y)
            if any(isinstance(e, Box) for e in other_teleported_position_entities):  # if there is a box on the other teleporter
                box_on_teleporter = next(e for e in other_teleported_position_entities if isinstance(e, Box))
                push_x = box_on_teleporter.x - entity.x
                push_y = box_on_teleporter.y - entity.y
                # Check if the box can be moved in the pushed direction
                if can_move_to(box_on_teleporter, push_x, push_y, state):
                    move_entity(box_on_teleporter, push_x, push_y, state)  # push the box
                    entity.x = other_teleporter.x
                    entity.y = other_teleporter.y
                    return True
            else:
                entity.x = other_teleporter.x
                entity.y = other_teleporter.y
                return True
        return False
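
To make the learned model concrete, here is a minimal usage sketch that rolls the transition function forward by hand. The level layout and action sequence are made up for illustration; only the entity classes and transition function above come from the learned program.

    import copy

    # A small hand-built level on the 7x7 grid assumed by can_move_to
    # (coordinates 0-6). This layout is illustrative, not from the paper.
    state = [
        Player(1, 1),
        Teleporter(1, 3),
        Teleporter(5, 3),
        Box(5, 4),
        Target(5, 6),
    ]

    # transition() mutates its argument, so we copy the state before each step;
    # a planner searching over possible futures would need to do the same.
    for action in ["move down", "move down", "activate", "move down", "move down"]:
        state = transition(copy.deepcopy(state), action)
        print(action, "->", state)

Here the "activate" step teleports the player next to the box, after which two pushes place the box on the target.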


Learned Reward Function

    def reward_func(state, action, next_state):
        targets_state = get_entities_by_name(state, 'Target')
        boxes_state = get_entities_by_name(state, 'Box')
        targets_next_state = get_entities_by_name(next_state, 'Target')
        boxes_next_state = get_entities_by_name(next_state, 'Box')
        # Count the number of boxes placed on targets in the state and next_state.
        boxes_on_targets_state = sum([any(box.x == target.x and box.y == target.y for box in boxes_state) for target in targets_state])
        boxes_on_targets_next_state = sum([any(box.x == target.x and box.y == target.y for box in boxes_next_state) for target in targets_next_state])
        # Determine the reward based on an increase in count of boxes on targets
        if boxes_on_targets_next_state > boxes_on_targets_state:
            reward = 0.9  # box was moved onto a target
        else:
            reward = -0.1  # no box was moved onto a target
        # Determine if the game is finished
        done = all(any(box.x == target.x and box.y == target.y for box in boxes_next_state) for target in targets_next_state)    
        if done:
            reward += 10  # Adding 10 bonus for completing the game
        return reward, done
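
As a quick check of the learned reward, the sketch below evaluates it on two hand-made transitions (the states are illustrative, not from the paper), reusing the entity classes and helpers defined alongside the transition function above.

    # Pushing the box one step right, but not yet onto the target.
    state = [Player(1, 2), Box(2, 2), Target(4, 2)]
    next_state = [Player(2, 2), Box(3, 2), Target(4, 2)]
    print(reward_func(state, "move right", next_state))  # (-0.1, False)

    # The final push lands the box on the last uncovered target.
    state = [Player(2, 2), Box(3, 2), Target(4, 2)]
    next_state = [Player(3, 2), Box(4, 2), Target(4, 2)]
    print(reward_func(state, "move right", next_state))  # (10.9, True): 0.9 for the push plus the +10 completion bonus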

Overall agent architecture.
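
As a rough textual stand-in for the figure, the pseudocode below sketches the loop it depicts, based on the summary at the top of the page: interact with the environment, refine a code world model against the replay buffer, and plan with it. The function signature and the two callables it takes are placeholders for this sketch, not the paper's actual API.

    def worldcoder_agent(env, synthesize_world_model, plan, goal, num_episodes=10):
        # Sketch only: `synthesize_world_model` stands in for the LLM-driven step that
        # writes/revises Python transition and reward code until it is consistent with
        # the replay buffer; `plan` stands in for planning against that learned code.
        replay = []         # observed (state, action, reward, next_state, done) tuples
        world_model = None  # Python source for transition/reward, written by the LLM
        for _ in range(num_episodes):
            state, done = env.reset(), False
            while not done:
                world_model = synthesize_world_model(world_model, replay)
                action = plan(world_model, state, goal)
                next_state, reward, done = env.step(action)
                replay.append((state, action, reward, next_state, done))
                state = next_state
        return world_model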


Qualitative comparison of WorldCoder against deep model-based RL and LLM agents. (∗LLM agents do not update their world model).

BibTeX

@inproceedings{tang2024worldcoder,
  title={{WorldCoder}, a Model-Based {LLM} Agent: Building World Models by Writing Code and Interacting with the Environment},
  author={Tang, Hao and Key, Darren and Ellis, Kevin},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}