忍者ブログ

Memeplexes

プログラミング、3DCGとその他いろいろについて

かんたん!制限付きボルツマンマシン実装 実数(0から1)バージョン (C#)訂正その2、 0と1を反転 [Deep Learningシリーズ]

0と1を反転すると…

ここで制限(制約)付きボルツマンマシン(Restricted Boltzmann Machine : RBM)を実装しましたが、いくつか難点があります。
これが私の実装のバグのせいなのか制限付きボルツマンマシンそのものの特性のせいなのかはわかりませんが、データに0に近い値が多い場合、0と1を入れ替えると、つまりデータに1に近い値が多くなると、上手く学習しないのです。


というわけで0と1を入れ替えてもうまくいくようにしましょう。
しかしまず、具体的にどういうケースでおかしくなるのかを見てみましょう。
前回のプログラムで、次のような学習データを与えます。

1.0 0.5 1.0 1.0 1.0
1.0 1.0 1.0 0.5 1.0
1.0 1.0 0.5 1.0 1.0

そして、
学習係数×学習データ数 = 0.3
学習の繰り返し回数(Epoch)= 1000
隠れニューロンの数 = 4
とします。

そして学習し、学習後に、学習データそのものを与えて想起します。
なお、想起にはランダム性があるので、それぞれの学習データに付き10回繰り返しました。
すると次のようになりました(小数点以下2桁):

1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.81    0.77    0.90    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.81    0.77    0.90    1.00

1.00    0.81    0.77    0.90    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00

1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.85    0.83    0.83    1.00
1.00    0.81    0.77    0.90    1.00
1.00    0.81    0.77    0.90    1.00
1.00    0.85    0.83    0.83    1.00

3つ結果があることがお分かりいただけるでしょう。
この3つはそれぞれ、学習データの3つに対応します。
学習データを読ませて連想したデータを、10個ずつ、ここでは出力しています。
(故にここには3 × 10 = 30行のデータがあります)

さてこの結果を見ると……混ざっています。
3つの0.5が他の1.0と混ざり混ざって0.85とかになっています。
つまり、全然うまく想起できていません。
そもそも学習できていないのです。
さてどうすればいいでしょうか?

ちなみに、同じプログラムで次のようなデータを与えたとします:

0.0 0.5 0.0 0.0 0.0
0.0 0.0 0.0 0.5 0.0
0.0 0.0 0.5 0.0 0.0

つまり、0が多いデータです。
上は1が多いデータだったので、0と1を反転したのです。
するとこうなります:

0.00    0.58    0.00    0.01    0.00
0.00    0.58    0.00    0.01    0.00
0.00    0.01    0.04    0.03    0.00
0.00    0.51    0.00    0.01    0.00
0.01    0.41    0.26    0.00    0.01
0.00    0.47    0.07    0.00    0.00
0.00    0.21    0.00    0.32    0.00
0.00    0.58    0.00    0.01    0.00
0.00    0.00    0.76    0.00    0.00
0.00    0.21    0.00    0.32    0.00

0.00    0.00    0.00    0.55    0.00
0.00    0.00    0.02    0.68    0.00
0.00    0.00    0.00    0.55    0.00
0.00    0.00    0.20    0.15    0.00
0.00    0.00    0.00    0.55    0.00
0.00    0.01    0.01    0.01    0.00
0.00    0.01    0.04    0.03    0.00
0.01    0.00    0.56    0.23    0.01
0.00    0.21    0.00    0.32    0.00
0.00    0.00    0.02    0.68    0.00

0.01    0.41    0.26    0.00    0.01
0.00    0.00    0.76    0.00    0.00
0.00    0.01    0.39    0.00    0.00
0.00    0.01    0.39    0.00    0.00
0.00    0.01    0.39    0.00    0.00
0.01    0.00    0.56    0.23    0.01
0.00    0.01    0.39    0.00    0.00
0.01    0.15    0.03    0.06    0.01
0.01    0.00    0.56    0.23    0.01
0.00    0.01    0.39    0.00    0.00

んー……まあまあ悪くはありません(良くもありませんが)。
想起できているといえば想起できています。
つまり、0が多いとそこそこうまく想起できており、1が多いと上手く想起できないのです。
少なくとも、この実装では。

上手く想起するには

適当にコードをいじっていると、なんだかうまくいきました。

  • ニューロンの発火状態の取る値を0~1にするのを止め、(ホップフィールドネットワークでよくあるように)-1~1にします。
  • 温度を低くします(0.1くらいに)。
  • 隠れニューロンの数を5に増やします。

あと、これは意味があるのかどうかわかりませんが、学習時の隠れニューロンの温度と想起時の隠れニューロンの温度を変えています。
想起時には温度を下げ、決定論的に動かしているのです。
でもこれは意味があるのかどうかわからないので、皆さんは鵜呑みにしないでくださいね。
以下にコードを載せます:

Program.cs

using System;
using System.Linq;

namespace RestrictedBoltzmannMachines.RealValue
{
    class Program
    {
        static void Main(string[] args)
        {
            double[][] trainingDataList = {
                new double[]{1.0, 0.5, 1.0, 1.0, 1.0},
                new double[]{1.0, 1.0, 0.5, 1.0, 1.0},
                new double[]{1.0, 1.0, 1.0, 0.5, 1.0},
		    };
            //trainingDataList = invert(trainingDataList);

            var hiddenNeuronCount = 5;
            var visibleNeuronCount = trainingDataList[0].Length;

            var restrictedBoltzmannMachine = new RestrictedBoltzmannMachine(
                visibleNeuronCount,
                hiddenNeuronCount,
                new Random(0)
                );

            var trainingEpochCount = 1000;
            var basicLearningRate = 0.1;

            // train
            for (int epoch = 0; epoch < trainingEpochCount; epoch++)
            {
                restrictedBoltzmannMachine.LearnFromData(
                    trainingDataList,
                    basicLearningRate / trainingDataList.Length
                    );
            }

            foreach (var input in trainingDataList)
            {
                restrictedBoltzmannMachine.SetVisibleNeuronValues(input);
                restrictedBoltzmannMachine.Associate();

                foreach (var output in restrictedBoltzmannMachine.VisibleNeurons)
                {
                    Console.Write("{0:F2}\t", output.Value);
                }

                Console.WriteLine();
            }
        }

        static double[][] invert(double[][] input)
        {
            return input.Select(i => i.Select(n => 1 - n).ToArray()).ToArray();
        }
    }
}

RestrictedBoltzmannMachine.cs

using System;
using System.Threading.Tasks;
using System.Linq;

namespace RestrictedBoltzmannMachines.RealValue
{
    public class RestrictedBoltzmannMachine
    {
        public SymmetricConnection[][] Connections;
        public VisibleNeuron[] VisibleNeurons;
        public HiddenNeuron[] HiddenNeurons;
        public double LearningTemperature = 0.1;

        public RestrictedBoltzmannMachine(int visibleNeuronCount, int hiddenNeuronCount, Random random) :
            this(
                SymmetricConnection.CreateRandomWeights(random, visibleNeuronCount, hiddenNeuronCount),
                new double[visibleNeuronCount],
                new double[hiddenNeuronCount],
                random)
        {
        }

        public RestrictedBoltzmannMachine(double[][] weights, double[] visibleBiases, double[] hiddenBiases, Random random)
        {
            this.VisibleNeurons = Neuron.CreateNeurons<VisibleNeuron>(visibleBiases);
            this.HiddenNeurons = Neuron.CreateNeurons<HiddenNeuron>(hiddenBiases);
            this.Connections = SymmetricConnection.CreateConnections(weights, VisibleNeurons, HiddenNeurons);
            Neuron.WireConnections(this.Connections);

            foreach (var neuron in this.VisibleNeurons)
            {
                neuron.Random = new Random(random.Next());
            }

            foreach (var neuron in this.HiddenNeurons)
            {
                neuron.Random = new Random(random.Next());
            }
        }

        public void SetVisibleNeuronValues(double[] visibleValues)
        {
            for (int i = 0; i < this.VisibleNeurons.Length; i++)
            {
                this.VisibleNeurons[i].Value = visibleValues[i];
            }
        }

        public double[] GetVisibleNeuronValues()
        {
            return VisibleNeurons.Select(n => n.Value).ToArray();
        }

        // for deep learning
        public void SetHiddenNeuronValues(double[] hiddenValues)
        {
            for (int i = 0; i < this.HiddenNeurons.Length; i++)
            {
                this.HiddenNeurons[i].Value = hiddenValues[i];
            }
        }

        public void LearnFromData(double[][] dataList, double learningRate, int freeAssociationStepCount = 1)
        {
            foreach (var data in dataList)
            {
                SetVisibleNeuronValues(data);
                wake(learningRate);
                sleep(learningRate, freeAssociationStepCount);
                endLearning();
            }
        }

        private void wake(double learningRate)
        {
            UpdateHiddenNeurons(LearningTemperature);
            learn(learningRate);
        }

        private void updateNeurons(Neuron[] neurons, double temperature)
        {
            Parallel.ForEach(neurons, neuron => neuron.Update(temperature));
        }

        public void UpdateVisibleNeurons()
        {
            updateNeurons(this.VisibleNeurons, this.LearningTemperature);
        }

        public void UpdateHiddenNeurons(double temperature)
        {
            updateNeurons(this.HiddenNeurons, temperature);
        }

        // for deep learning
        public void SetVisibleNeuronValues(RestrictedBoltzmannMachine other)
        {
            for (int i = 0; i < this.VisibleNeurons.Length; i++)
            {
                this.VisibleNeurons[i].Value = other.HiddenNeurons[i].Value;
            }
        }

        // for deep learning
        public void SetHiddenNeuronValues(RestrictedBoltzmannMachine other)
        {
            for (int i = 0; i < this.HiddenNeurons.Length; i++)
            {
                this.HiddenNeurons[i].Value = other.VisibleNeurons[i].Value;
            }
        }

        // for deep learning
        public double[] VisibleToHidden(double[] visibleNeuronValues)
        {
            SetVisibleNeuronValues(visibleNeuronValues);
            UpdateHiddenNeurons(0);
            return HiddenNeurons.Select(n => n.Value).ToArray();
        }

        private void learn(double learningRate)
        {
            Parallel.ForEach(Connections, connectionRow =>
            {
                foreach (var connection in connectionRow)
                {
                    connection.Learn(learningRate);
                }
            });

            Parallel.ForEach(VisibleNeurons, neuron => neuron.Learn(learningRate));
            Parallel.ForEach(HiddenNeurons, neuron => neuron.Learn(learningRate));
        }

        private void sleep(double learningRate, int freeAssociationStepCount)
        {
            doFreeAssociation(freeAssociationStepCount);
            learn(-learningRate);
        }

        // Gibbs sampling
        private void doFreeAssociation(int freeAssociationStepCount)
        {
            for (int step = 0; step < freeAssociationStepCount; step++)
            {
                UpdateVisibleNeurons();
                UpdateHiddenNeurons(LearningTemperature);
            }
        }

        private void endLearning()
        {
            Parallel.ForEach(Connections, connectionRow =>
            {
                foreach (var connection in connectionRow)
                {
                    connection.EndLearning();
                }
            });

            Parallel.ForEach(VisibleNeurons, neuron => neuron.EndLearning());
            Parallel.ForEach(HiddenNeurons, neuron => neuron.EndLearning());
        }

        public void Associate()
        {
            UpdateHiddenNeurons(0);
            UpdateVisibleNeurons();
        }
    }
}

Neuron.cs

using System;
using System.Collections.Generic;

namespace RestrictedBoltzmannMachines.RealValue
{
    public abstract class Neuron
    {
        public double Value;
        public double Bias;
        public double DeltaBias;
        public List<Synapse> Synapses = new List<Synapse>();
        public Random Random;

        public double SymmetricValue
        {
            get
            {
                return 2 * Value - 1;
            }
        }

        public abstract void Update(double temperature);
        public abstract void Learn(double learningRate);


        public void EndLearning()
        {
            this.Bias += this.DeltaBias;
            this.DeltaBias = 0;
        }

        public static T[] CreateNeurons<T>(double[] biases)
            where T : Neuron, new()
        {
            T[] result = new T[biases.Length];

            for (int i = 0; i < result.Length; i++)
            {
                result[i] = new T { Bias = biases[i] };
            }

            return result;
        }

        public static void WireConnections(SymmetricConnection[][] connections)
        {
            foreach (var connectionRow in connections)
            {
                foreach (var connection in connectionRow)
                {
                    Synapse hiddenConnection = new Synapse();
                    hiddenConnection.Connection = connection;
                    hiddenConnection.SourceNeuron = connection.VisibleNeuron;
                    connection.HiddenNeuron.Synapses.Add(hiddenConnection);

                    Synapse visibleConnection = new Synapse();
                    visibleConnection.Connection = connection;
                    visibleConnection.SourceNeuron = connection.HiddenNeuron;
                    connection.VisibleNeuron.Synapses.Add(visibleConnection);
                }
            }
        }

        protected double GetInputFromSourceNeurons()
        {
            double result = 0;

            for (int i = 0; i < Synapses.Count; i++)
            {
                var s = Synapses[i];
                result += s.Connection.Weight * s.SourceNeuron.SymmetricValue;
            }

            return result;
        }

        protected static double Sigmoid(double x, double temperature)
        {
            return 1.0 / (1.0 + Math.Exp(-x / temperature));
        }
    }

    public class VisibleNeuron : Neuron
    {
        public override void Update(double temperature)
        {
            this.Value = Sigmoid(GetInputFromSourceNeurons() + Bias, temperature);
        }

        public override void Learn(double learningRate)
        {
            this.DeltaBias += learningRate * this.SymmetricValue;
        }
    }

    public class HiddenNeuron : Neuron
    {
        public double Probability;

        public double SymmetricProbability
        {
            get
            {
                return 2 * Probability - 1;
            }
        }

        public override void Update(double temperature)
        {
            this.Probability = Sigmoid(GetInputFromSourceNeurons() + Bias, temperature);
            this.Value = nextBool(Random, this.Probability) ? 1 : 0;
        }

        private static bool nextBool(Random random, double rate)
        {
            return random.NextDouble() < rate;
        }

        public override void Learn(double learningRate)
        {
            this.DeltaBias += learningRate * this.SymmetricProbability;
        }
    }
}

Synapse.cs

using System;

namespace RestrictedBoltzmannMachines.RealValue
{
    public struct Synapse
    {
        public Neuron SourceNeuron;
        public SymmetricConnection Connection;
    }

    public class SymmetricConnection
    {
        public double Weight;
        public double DeltaWeight;
        public VisibleNeuron VisibleNeuron;
        public HiddenNeuron HiddenNeuron;

        public void Learn(double learningRate)
        {
            this.DeltaWeight +=
                learningRate * VisibleNeuron.SymmetricValue * HiddenNeuron.SymmetricProbability;
        }

        public void EndLearning()
        {
            this.Weight += this.DeltaWeight;
            this.DeltaWeight = 0;
        }

        public static double[][] CreateRandomWeights(Random random, int visibleNeuronCount, int hiddenNeuronCount)
        {
            var result = createJaggedArray<double>(visibleNeuronCount, hiddenNeuronCount);

            double a = 1.0 / visibleNeuronCount;

            for (int i = 0; i < visibleNeuronCount; i++)
            {
                for (int j = 0; j < hiddenNeuronCount; j++)
                {
                    result[i][j] = uniform(random, -a, a);
                }
            }

            return result;
        }

        private static T[][] createJaggedArray<T>(int visibleNeuronCount, int hiddenNeuronCount)
        {
            var result = new T[visibleNeuronCount][];

            for (int i = 0; i < visibleNeuronCount; i++)
            {
                result[i] = new T[hiddenNeuronCount];
            }

            return result;
        }

        private static double uniform(Random random, double min, double max)
        {
            return random.NextDouble() * (max - min) + min;
        }

        public static SymmetricConnection[][] CreateConnections(
            double[][] weights, 
            VisibleNeuron[] visibleNeurons,
            HiddenNeuron[] hiddenNeurons)
        {
            var result = createJaggedArray<SymmetricConnection>(visibleNeurons.Length, hiddenNeurons.Length);

            for (int i = 0; i < visibleNeurons.Length; i++)
            {
                for (int j = 0; j < hiddenNeurons.Length; j++)
                {
                    SymmetricConnection connection = new SymmetricConnection();
                    connection.Weight = weights[i][j];
                    connection.VisibleNeuron = visibleNeurons[i];
                    connection.HiddenNeuron = hiddenNeurons[j];
                    result[i][j] = connection;
                }
            }

            return result;
        }
    }
}

実行結果

すると次のようにな結果になります:

1.00    0.50    1.00    1.00    1.00
1.00    1.00    0.50    1.00    1.00
1.00    1.00    1.00    0.50    1.00

完璧な結果になりましたね。もうすこしエポック少なくしたいですが。

拍手[0回]

PR