You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
2.4 KiB

import numpy as np
import matplotlib.pyplot as plt
# GA: genetic algorithm minimising the Rastrigin benchmark function.
def rastrigin(x, *, amplitude=10):
    """Evaluate the Rastrigin function at point ``x``.

    f(x) = A * n + sum_i (x_i**2 - A * cos(2 * pi * x_i)), where n is the
    number of coordinates. Every term is >= 0, so f(0) = 0 is the minimum.

    Args:
        x: The point to evaluate; any array-like (list, tuple, scalar or
            ndarray). It is treated as a flat vector of coordinates.
        amplitude: The constant A (10 in the standard definition).

    Returns:
        The function value as a NumPy float.
    """
    # asarray lets plain lists/tuples work (``list ** 2`` raises TypeError).
    x = np.asarray(x, dtype=float)
    # x.size (not len(x)) so that n counts the same coordinates that np.sum
    # sums over, including for scalar input.
    return amplitude * x.size + np.sum(x**2 - amplitude * np.cos(2 * np.pi * x))
def create_population(pop_size, n, *, low=-5.12, high=5.12):
    """Create a random initial population.

    Args:
        pop_size: Number of individuals (rows).
        n: Number of genes per individual (columns).
        low: Lower sampling bound (inclusive).
        high: Upper sampling bound (exclusive). The defaults are the standard
            Rastrigin search domain [-5.12, 5.12].

    Returns:
        Float array of shape ``(pop_size, n)`` with every gene drawn
        independently and uniformly from ``[low, high)`` via the global
        ``np.random`` state.
    """
    return np.random.uniform(low, high, size=(pop_size, n))
def evaluate_population(population):
    """Score each individual (row) of ``population`` with ``rastrigin``.

    Returns:
        A 1-D NumPy array holding one fitness value per row, in row order.
        Lower values are better, because the GA minimises.
    """
    scores = list(map(rastrigin, population))
    return np.array(scores)
def select_parents(population, fitness, num_parents):
    """Pick the ``num_parents`` fittest individuals (lowest fitness first).

    Args:
        population: 2-D array with one individual per row.
        fitness: Per-row fitness values; lower is better.
        num_parents: How many rows to keep.

    Returns:
        A new array (fancy indexing copies) containing the selected rows,
        ordered from best to worst.
    """
    ranking = np.argsort(fitness)
    elite_rows = ranking[:num_parents]
    return population[elite_rows, :]
def crossover(parents, offspring_size):
    """Build children by single-point crossover at the middle of the genome.

    Child ``k`` takes the genes before the cut from ``parents[k % P]`` and
    the genes from the cut onward from ``parents[(k + 1) % P]``. Here P is
    the number of parent rows and the cut is ``offspring_size[1] // 2``.

    Args:
        parents: 2-D array with one parent per row.
        offspring_size: ``(num_children, num_genes)`` shape of the result.

    Returns:
        A new float array of shape ``offspring_size``.
    """
    cut = offspring_size[1] // 2
    children = np.empty(offspring_size)
    for row in range(offspring_size[0]):
        n_parents = parents.shape[0]
        head_parent = parents[row % n_parents]
        tail_parent = parents[(row + 1) % n_parents]
        children[row, :cut] = head_parent[:cut]
        children[row, cut:] = tail_parent[cut:]
    return children
def mutation(offspring_crossover, *, mutation_rate=0.01, low=-5.12, high=5.12):
    """Apply random-reset mutation to a batch of genomes, in place.

    Each gene is replaced, independently and with probability
    ``mutation_rate``, by a value drawn uniformly from ``[low, high)``.
    Random numbers come from the global ``np.random`` state. Genes are
    visited in row-major order, with one ``rand()`` draw per gene plus one
    ``uniform()`` draw per gene that mutates, so seeded runs are reproducible.

    Args:
        offspring_crossover: Array of genomes (normally 2-D, one per row).
            It is modified in place.
        mutation_rate: Per-gene mutation probability (default 0.01).
        low: Lower bound for replacement values (default -5.12).
        high: Upper bound for replacement values (default 5.12).

    Returns:
        The same array object that was passed in, after mutation.
    """
    # np.ndindex yields indices in C (row-major) order, matching a nested
    # row/column loop, so the sequence of random draws is unchanged.
    for index in np.ndindex(offspring_crossover.shape):
        if np.random.rand() < mutation_rate:
            offspring_crossover[index] = np.random.uniform(low, high)
    return offspring_crossover
def genetic_algorithm(pop_size, n, num_generations, *, num_parents=2):
    """Minimise the Rastrigin function with a simple elitist genetic algorithm.

    Each generation runs these steps:
      1. Evaluate fitness (lower is better).
      2. Keep the ``num_parents`` fittest individuals unchanged (elitism).
      3. Refill the rest of the population with crossover children of those
         parents, then mutate the children.

    Args:
        pop_size: Number of individuals per generation.
        n: Number of genes (dimensions) per individual.
        num_generations: Number of generations to run.
        num_parents: Number of elite individuals kept and bred (default 2).

    Returns:
        ``(best_individual, best_fitness_record)``:
        - ``best_individual`` is an independent copy of the best genome seen,
          or ``None`` if ``num_generations`` is 0.
        - ``best_fitness_record`` lists the best fitness seen so far, one
          entry per generation.

    Raises:
        ValueError: If ``num_parents < 1`` or ``pop_size < num_parents``.
    """
    if num_parents < 1 or pop_size < num_parents:
        raise ValueError(
            f"num_parents must be in [1, pop_size]; "
            f"got num_parents={num_parents}, pop_size={pop_size}"
        )
    population = create_population(pop_size, n)
    best_fitness = np.inf
    best_individual = None
    best_fitness_record = []
    for _ in range(num_generations):
        fitness = evaluate_population(population)
        generation_best = np.argmin(fitness)
        if fitness[generation_best] < best_fitness:
            best_fitness = fitness[generation_best]
            # Copy: population[i, :] is a view, and population's rows are
            # overwritten in place below, which would corrupt the stored best.
            best_individual = population[generation_best, :].copy()
        parents = select_parents(population, fitness, num_parents=num_parents)
        offspring = crossover(parents, offspring_size=(pop_size - num_parents, n))
        offspring = mutation(offspring)
        population[num_parents:, :] = offspring
        population[:num_parents, :] = parents
        best_fitness_record.append(best_fitness)
    return best_individual, best_fitness_record
if __name__ == "__main__":
    # Run the GA on the 3-D Rastrigin function and plot convergence. The
    # guard keeps the 1000-generation run and the plot window from firing
    # when this module is merely imported.
    best_individual, best_fitness_record = genetic_algorithm(
        pop_size=50, n=3, num_generations=1000
    )
    print(f"Best individual: {best_individual}, fitness: {best_fitness_record[-1]}")
    plt.plot(best_fitness_record)
    plt.title('Genetic Algorithm for Rastrigin Function')
    plt.xlabel('Generation')
    plt.ylabel('Best Fitness')
    plt.show()