diff --git a/neat/genome.py b/neat/genome.py index f32ed46e..61a433cc 100644 --- a/neat/genome.py +++ b/neat/genome.py @@ -310,6 +310,7 @@ def mutate_add_node(self, config): conn_to_split = choice(list(self.connections.values())) new_node_id = config.get_new_node_key(self.nodes) ng = self.create_node(config, new_node_id) + ng.bias = 0.0 self.nodes[new_node_id] = ng # Disable this connection and create two new connections joining its nodes via @@ -327,23 +328,27 @@ def add_connection(self, config, input_key, output_key, weight, enabled): assert isinstance(output_key, int) assert output_key >= 0 assert isinstance(enabled, bool) - key = (input_key, output_key) - connection = config.connection_gene_type(key) - connection.init_attributes(config) + connection = self.create_connection(config, input_key, output_key) connection.weight = weight connection.enabled = enabled - self.connections[key] = connection + self.connections[connection.key] = connection def mutate_add_connection(self, config): """ Attempt to add a new connection, the only restriction being that the output node cannot be one of the network input pins. """ + # TODO: Maybe keep trying to add connection until successful, so that the probability holds. + # First check if new connections are even possible (costly?) to avoid infinite loop. possible_outputs = list(self.nodes) out_node = choice(possible_outputs) possible_inputs = possible_outputs + config.input_keys in_node = choice(possible_inputs) + + # Don't use output nodes as inputs + if in_node in config.output_keys: + return # Don't duplicate connections. key = (in_node, out_node) @@ -353,13 +358,6 @@ def mutate_add_connection(self, config): self.connections[key].enabled = True return - # Don't allow connections between two output nodes - if in_node in config.output_keys and out_node in config.output_keys: - return - - # No need to check for connections between input nodes: - # they cannot be the output end of a connection (see above). - # For feed-forward networks, avoid creating cycles. if config.feed_forward and creates_cycle(list(self.connections), key): return @@ -374,23 +372,25 @@ def mutate_delete_node(self, config): return -1 del_key = choice(available_nodes) + del self.nodes[del_key] - connections_to_delete = set() - for k, v in self.connections.items(): - if del_key in v.key: - connections_to_delete.add(v.key) - - for key in connections_to_delete: + # Delete connections of chosen node + for key in [k for k in self.connections if del_key in k]: del self.connections[key] - del self.nodes[del_key] - return del_key def mutate_delete_connection(self): - if self.connections: - key = choice(list(self.connections.keys())) - del self.connections[key] + # Do nothing if there are no connections + if not self.connections: + return -1 + + del_key = choice(list(self.connections)) + del self.connections[del_key] + + # TODO: Also delete nodes if left floating? Iterating may be needed. + + return del_key def distance(self, other, config): """