
Q-Learning Sample Code (C#)

I tried writing some sample code for Q-learning, using the problem setup from here: a small graph of six states (0 through 5) in which every move into the goal state 5 earns a reward of 100 (see initStates below).
The language is C#.


Code

Program.cs

using System;
using System.Linq;

namespace QLearningDemo
{
    class Program
    {
        const int Iterations = 30;

        // Fixed seed so that each run produces the same result.
        private static Random random = new Random(10);
        private static QLearning learning = new QLearning();

        public static void Main()
        {
            initStates();
            train();
            test();
        }

        private static void initStates()
        {
            for (int i = 0; i < 6; i++)
            {
                learning.States.Add(new State());
            }

            // Directed edges of the state graph; the reward is paid on the
            // transition, and every move into the goal state 5 pays 100.
            learning.AddAction(0, 4, 0);
            learning.AddAction(4, 0, 0);

            learning.AddAction(4, 3, 0);
            learning.AddAction(3, 4, 0);

            learning.AddAction(2, 3, 0);
            learning.AddAction(3, 2, 0);

            learning.AddAction(1, 3, 0);
            learning.AddAction(3, 1, 0);

            learning.AddAction(1, 5, 100);
            learning.AddAction(5, 1, 0);

            learning.AddAction(4, 5, 100);
            learning.AddAction(5, 4, 0);

            learning.AddAction(5, 5, 100);

            learning.GoalState = 5;
        }

        private static void train()
        {
            // Run one learning episode from every start state, Iterations times.
            for (int i = 0; i < Iterations; i++)
            {
                for (int start = 0; start < learning.States.Count; start++)
                {
                    learning.LearnEpisode(random, start);
                }
            }

            Console.WriteLine("各行動の価値(いわゆるQ値) : ");

            foreach (State state in learning.States)
            {
                Console.WriteLine("状態{0} : ", learning.GetStateIndex(state));

                foreach (
                    Action action
                    in state.Actions.OrderByDescending((a) => a.ActionValue)
                    )
                {
                    Console.WriteLine(
                        "\t状態{0}へ移動 : {1}",
                        learning.GetStateIndex(action.Destination),
                        action.ActionValue);
                }

                Console.WriteLine();
            }

            Console.WriteLine();
        }

        private static void test()
        {
            Console.WriteLine("最短ルート:");

            for (int start = 0; start < learning.States.Count; start++)
            {
                learning.CurrentState = learning.States[start];

                do
                {
                    Console.Write(getCurrentStateNumber() + ", ");
                    learning.DoBestAction();
                }
                while (!learning.IsOnGoal());

                Console.WriteLine(getCurrentStateNumber());
            }
        }

        private static int getCurrentStateNumber()
        {
            return learning.GetStateIndex(learning.CurrentState);
        }
    }
}
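
Incidentally, dropping the four source files into a .NET console project and running it should reproduce the output shown in the Results section; only System, System.Linq and System.Collections.Generic are used, so any reasonably recent C# toolchain should do.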

QLearning.cs

using System;
using System.Collections.Generic;

namespace QLearningDemo
{
    class QLearning
    {
        public State CurrentState = null;
        public List<State> States = new List<State>();
        public int? GoalState = null;

        public void AddAction(int from, int to, int reward)
        {
            States[from].Actions.Add(
                    new Action(States[to], reward)
                    );
        }

        public void LearnEpisode(Random random, int initialState)
        {
            // One episode: start from the given state and wander randomly,
            // updating Q-values, until the goal is reached.
            this.CurrentState = States[initialState];

            do
            {
                transitState(random);
            }
            while (!IsOnGoal());
        }

        private void transitState(Random random)
        {
            Action randomAction = CurrentState.GetRandomAction(random);
            randomAction.Learn();
            this.CurrentState = randomAction.Destination;
        }

        public bool IsOnGoal()
        {
            return CurrentState == States[GoalState.Value];
        }

        public void DoBestAction()
        {
            this.CurrentState = CurrentState
                    .GetBestAction()
                    .Destination;
        }

        public int GetStateIndex(State state)
        {
            return States.IndexOf(state);
        }
    }
}

State.cs

using System;
using System.Collections.Generic;
using System.Linq;

namespace QLearningDemo
{
    class State
    {
        public List<Action> Actions = new List<Action>();

        public Action GetBestAction()
        {
            // Greedy choice: the action with the highest Q-value so far.
            return Actions.OrderByDescending((a) => a.ActionValue).First();
        }

        public Action GetRandomAction(Random random)
        {
            return Actions[random.Next(Actions.Count)];
        }
    }
}
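
One thing to note: during training, transitState always picks a uniformly random action, so the Q-values are exploited only in the final test. A common variant is ε-greedy exploration, which mixes random and greedy choices. Here is a minimal sketch of what that could look like on top of this State class (EpsilonGreedy is just an illustrative name, not something the code above uses):

using System;

namespace QLearningDemo
{
    static class ExplorationExtensions
    {
        // With probability epsilon take a random action (explore),
        // otherwise take the currently best one (exploit).
        public static Action EpsilonGreedy(
            this State state, Random random, double epsilon)
        {
            return random.NextDouble() < epsilon
                ? state.GetRandomAction(random)
                : state.GetBestAction();
        }
    }
}

transitState could then call CurrentState.EpsilonGreedy(random, 0.1) instead of CurrentState.GetRandomAction(random).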

Action.cs

namespace QLearningDemo
{
    class Action
    {
        public double ActionValue;
        public int Reward;
        public State Destination;

        const double LearningRate = 0.7;           // the learning rate, alpha
        const double EffectFromDestination = 0.8;  // the discount factor, gamma

        public Action(State destination, int reward)
        {
            this.Destination = destination;
            this.Reward = reward;
        }

        public void Learn()
        {
            // Q(s,a) <- (1 - LearningRate) * Q(s,a)
            //           + LearningRate * (Reward + EffectFromDestination * max Q(s',a'))
            double valueOfDestination = EffectFromDestination
                    * Destination.GetBestAction().ActionValue;
            this.ActionValue *= (1 - LearningRate);
            this.ActionValue +=
                LearningRate * (Reward + valueOfDestination);
        }
    }
}
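
For reference, Learn is the standard Q-learning update, with LearningRate as the learning rate $\alpha = 0.7$ and EffectFromDestination as the discount factor $\gamma = 0.8$:

$$Q(s,a) \leftarrow (1-\alpha)\,Q(s,a) + \alpha\left(r + \gamma \max_{a'} Q(s',a')\right)$$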

Results

Value of each action (the so-called Q-value):
State 0:
        Move to state 4: 353.854164915573

State 1:
        Move to state 5: 442.334228421353
        Move to state 3: 283.07180061273

State 2:
        Move to state 3: 283.065465820059

State 3:
        Move to state 1: 353.861890375022
        Move to state 4: 353.840814701139
        Move to state 2: 226.436642861787

State 4:
        Move to state 5: 442.334228788084
        Move to state 0: 283.077595817974
        Move to state 3: 283.061058149671

State 5:
        Move to state 5: 427.926339480848
        Move to state 1: 350.728622125388
        Move to state 4: 340.295484032348


Shortest routes:
0, 4, 5
1, 5
2, 3, 1, 5
3, 1, 5
4, 5
5, 5
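
As a quick sanity check, these numbers are consistent with the update rule: the move from state 0 to state 4 has reward 0, so its Q-value should settle near γ times the best value available from state 4, and indeed 0.8 × 442.334 ≈ 353.87, close to the learned 353.854.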

