ネルダーミード法を使った救解

ChatGPTにネルダーミード法を使った関数の最適解を求めてもらった

あってるかは後日確認するつもり

% 最小化する関数
func = @(x) (x(1) - 3)^2 + (x(2) - 2)^2;

% 初期点
x0 = [0; 0];

% 収束許容誤差
tol = 1e-6;

% 最大繰り返し回数
max_iter = 200;

% ネルダー・ミード法を実行
[x_opt, fval] = nelder_mead(func, x0, tol, max_iter);

fprintf('最適解: (%f, %f)\n', x_opt(1), x_opt(2));
fprintf('関数値: %f\n', fval);

function [x_opt, fval] = nelder_mead(func, x0, tol, max_iter)
    % 初期化
    n = length(x0); % 次元数
    alpha = 1; % 反射係数
    gamma = 2; % 展開係数
    rho = 0.5; % 収縮係数
    sigma = 0.5; % 縮小係数
    
    % 初期シンプルックスの生成
    simplex = zeros(n, n + 1);
    simplex(:,1) = x0;
    for i = 1:n
        x = x0;
        x(i) = x(i) + 0.05; % 初期の小さな変化
        simplex(:, i+1) = x;
    end
    
    % 初期の関数値を評価
    fval = zeros(1, n + 1);
    for i = 1:n + 1
        fval(i) = feval(func, simplex(:, i));
    end
    
    iter = 0;
    
    while iter < max_iter
        % 順序付け
        [fval, idx] = sort(fval);
        simplex = simplex(:, idx);
        
        % 収束判定
        if max(abs(fval(1) - fval(2:end))) < tol
            break;
        end
        
        % 重心の計算
        x_bar = mean(simplex(:, 1:n), 2);
        
        % 反射
        x_r = x_bar + alpha * (x_bar - simplex(:, end));
        f_r = feval(func, x_r);
        
        if f_r < fval(1)
            % 展開
            x_e = x_bar + gamma * (x_r - x_bar);
            f_e = feval(func, x_e);
            if f_e < f_r
                simplex(:, end) = x_e;
                fval(end) = f_e;
            else
                simplex(:, end) = x_r;
                fval(end) = f_r;
            end
        else
            if f_r < fval(end-1)
                simplex(:, end) = x_r;
                fval(end) = f_r;
            else
                % 収縮
                if f_r < fval(end)
                    x_c = x_bar + rho * (x_r - x_bar);
                    f_c = feval(func, x_c);
                    if f_c <= f_r
                        simplex(:, end) = x_c;
                        fval(end) = f_c;
                    else
                        % 縮小
                        for i = 2:n + 1
                            simplex(:, i) = simplex(:, 1) + sigma * (simplex(:, i) - simplex(:, 1));
                            fval(i) = feval(func, simplex(:, i));
                        end
                    end
                else
                    x_c = x_bar + rho * (simplex(:, end) - x_bar);
                    f_c = feval(func, x_c);
                    if f_c < fval(end)
                        simplex(:, end) = x_c;
                        fval(end) = f_c;
                    else
                        % 縮小
                        for i = 2:n + 1
                            simplex(:, i) = simplex(:, 1) + sigma * (simplex(:, i) - simplex(:, 1));
                            fval(i) = feval(func, simplex(:, i));
                        end
                    end
                end
            end
        end
        iter = iter + 1;
    end
    
    x_opt = simplex(:, 1);
end

コメント

タイトルとURLをコピーしました