忍者ブログ

Memeplexes

プログラミング、3DCGとその他いろいろについて

かんたん!制限付きボルツマンマシン実装 実数(0から1)バージョン (C#) [Deep Learningシリーズ]

ニューロンの発火状態が実数

さて今回はあまり自信がありません。
もっとも私はいつも自信がないのですが、今回はいつもよりさらに自信がありません。
プログラムは上手く動くように見えるのですが、その理論的背景がよくわかりません!
間違いがあったら訂正していただけるとありがたいです。


[追記] どうもこちらのほうが正しいような気がしてきました。リンク先もご覧ください。どちらの性能が高いかというとこのページのやり方のほうがいい気がします。うーんどっちを使うべきなんでしょう?

[追記2] こちらもご覧ください。この記事のプログラムは、学習データに1が多いと上手く学習できないようです。それを修正しました。

今回のテーマは制限(制約)付きボルツマンマシンの実装、実数バージョンです。
この前の実装は、バイナリでした。
つまり、可視ニューロンには発火しているかどうかの2通りしかありませんでした。
発火状態は0か1かしか扱えなかったのです。
(その割には可視ニューロンの発火状態の型が(boolではなく)doubleだったりと妙なことになっていましたが)



バイナリなペンギン
2通りしか無いと画像はこうなります。)




ともかく!
この前は白か黒か二通りしか無いモノクロビットマップしか扱えなかったのにたいして、今回は灰色あり濃淡様々な色のビットマップが扱えるようになったのです。
もちろん、扱えるデータはビットマップに限りません。
しかしこれはきっと画像を処理するときに良い感じだと思います。
いろんな色が見えると見た目綺麗ですからね。

プログラムのコード

プログラムは4つのファイルからなります。
  • Program.cs : 制限付きボルツマンマシンを動かします。
  • RestrictedBoltzmannMachine.cs : 制限付きボルツマンマシンです。
  • Neuron.cs : ニューロンです。
  • Synapse.cs : ニューロン間の結合、シナプスです。

Program.cs

今回は実数を扱うので、前回はなかった数字を入れてみました。
0.5です。
0.5を含んだパターンを学習させます。
さて上手く0.5を思い出してくれるでしょうか…?

using System;

namespace RestrictedBoltzmannMachines.RealValue
{
    class Program
    {
        static void Main(string[] args)
        {
            double[][] trainingData = {
                new double[]{1.0, 0.5, 1.0, 0.0, 0.0},
                new double[]{0.0, 0.0, 1.0, 0.5, 1.0},
                new double[]{0.0, 1.0, 0.5, 1.0, 0.0}
		    };

            var hiddenNeuronCount = 3;
            var visibleNeuronCount = trainingData[0].Length;

            var restrictedBoltzmannMachine = new RestrictedBoltzmannMachine(
                visibleNeuronCount,
                hiddenNeuronCount,
                new Random(0)
                );

            var trainingEpochCount = 1000;
            var basicLearningRate = 0.1;

            // train
            for (int epoch = 0; epoch < trainingEpochCount; epoch++)
            {
                foreach (var data in trainingData)
                {
                    restrictedBoltzmannMachine.SetVisibleNeuronValues(data);
                    restrictedBoltzmannMachine.LearnFromData(basicLearningRate / trainingData.Length);
                }
            }

            double[][] testInputData = {
			    new double[]{1, 1, 0, 0, 0},
			    new double[]{0, 0, 0, 1, 1}
		    };

            foreach (var input in testInputData)
            {
                restrictedBoltzmannMachine.SetVisibleNeuronValues(input);
                restrictedBoltzmannMachine.Associate();

                foreach (var output in restrictedBoltzmannMachine.VisibleNeurons)
                {
                    Console.Write("{0:F2}\t", output.Value);
                }
                
                Console.WriteLine();
            }
        }
    }
}

RestrictedBoltzmannMachine.cs

特に変わりありません。
ただ、疑似乱数は隠れニューロンだけが必要としているので、隠れニューロンに1つずつSystem.Randomを生成してセットしています。
なかなか意味不明ですが、マルチスレッドにしたとき、何度実行しても同じ結果を出すためにはこうするしか無いのではないでしょうか。

using System;
using System.Threading.Tasks;

namespace RestrictedBoltzmannMachines.RealValue
{
    public class RestrictedBoltzmannMachine
    {
        public SymmetricConnection[][] Connections;
        public VisibleNeuron[] VisibleNeurons;
        public HiddenNeuron[] HiddenNeurons;

        public RestrictedBoltzmannMachine(int visibleNeuronCount, int hiddenNeuronCount, Random random) :
            this(SymmetricConnection.CreateRandomWeights(random, visibleNeuronCount, hiddenNeuronCount), new double[visibleNeuronCount], new double[hiddenNeuronCount], random)
        {
        }

        public RestrictedBoltzmannMachine(double[][] weights, double[] visibleBiases, double[] hiddenBiases, Random random)
        {
            this.VisibleNeurons = Neuron.CreateNeurons<VisibleNeuron>(visibleBiases);
            this.HiddenNeurons = Neuron.CreateNeurons<HiddenNeuron>(hiddenBiases);
            this.Connections = SymmetricConnection.CreateConnections(weights, VisibleNeurons, HiddenNeurons);
            Neuron.WireConnections(this.Connections);

            foreach (var neuron in this.HiddenNeurons)
            {
                neuron.Random = new Random(random.Next());
            }
        }

        public void SetVisibleNeuronValues(double[] visibleValues)
        {
            for (int i = 0; i < this.VisibleNeurons.Length; i++)
            {
                this.VisibleNeurons[i].Value = visibleValues[i];
            }
        }

        public void LearnFromData(double learningRate, int freeAssociationStepCount = 1)
        {
            Wake(learningRate);
            Sleep(learningRate, freeAssociationStepCount);
            EndLearning();
        }

        public void Wake(double learningRate)
        {
            UpdateHiddenNeurons();
            learn(learningRate);
        }

        public void UpdateVisibleNeurons()
        {
            updateNeurons(this.VisibleNeurons);
        }

        public void UpdateHiddenNeurons()
        {
            updateNeurons(this.HiddenNeurons);
        }

        private void updateNeurons(Neuron[] neurons)
        {
            Parallel.ForEach(neurons, neuron => neuron.Update());
        }

        private void learn(double learningRate)
        {
            foreach (var connectionRow in Connections)
            {
                foreach (var connection in connectionRow)
                {
                    connection.Learn(learningRate);
                }
            }

            foreach (var neuron in this.VisibleNeurons)
            {
                neuron.Learn(learningRate);
            }

            foreach (var neuron in this.HiddenNeurons)
            {
                neuron.Learn(learningRate);
            }
        }

        public void Sleep(double learningRate, int freeAssociationStepCount)
        {
            doFreeAssociation(freeAssociationStepCount);
            learn(-learningRate);
        }

        //Gibbs sampling
        private void doFreeAssociation(int freeAssociationStepCount)
        {
            for (int step = 0; step < freeAssociationStepCount; step++)
            {
                UpdateVisibleNeurons();
                UpdateHiddenNeurons();
            }
        }

        public void EndLearning()
        {
            foreach (var connectionRow in Connections)
            {
                foreach (var connection in connectionRow)
                {
                    connection.EndLearning();
                }
            }

            foreach (var neuron in this.VisibleNeurons)
            {
                neuron.EndLearning();
            }

            foreach (var neuron in this.HiddenNeurons)
            {
                neuron.EndLearning();
            }
        }

        public void Associate()
        {
            UpdateHiddenNeurons();
            UpdateVisibleNeurons();
        }
    }
}

Neuron.cs

ニューロンのクラスです。
ここはかなり変わりました。
今回は隠れニューロンと可視ニューロンがだいぶ違うものになったので、もう別のクラスに分けました。

まず可視ニューロンには発火確率がありません。
可視ニューロンの発火状態は0から1までの実数なので当然です。
0か1というなら発火確率に意味はありますが、0から1の連続した値を取るものに発火確率とはナンセンスです(本当でしょうか?不安になって来ました)。

かわりに可視ニューロンの発火状況は隠れニューロンの発火確率の式と同じにしました。
シグモイド関数に、自分と接続しているニューロンからの入力を入れるのです。
これで正しいのかどうかはわかりませんが、ほかのサイトのプログラムもそうなっていることですし、まあいいでしょう。

using System;
using System.Collections.Generic;
using System.Linq;

namespace RestrictedBoltzmannMachines.RealValue
{
    public abstract class Neuron
    {
        public double Value;
        public double Bias;
        public double DeltaBias;
        public List<Synapse> Synapses = new List<Synapse>();

        public abstract void Update();
        public abstract void Learn(double learningRate);

        public void EndLearning()
        {
            this.Bias += this.DeltaBias;
            this.DeltaBias = 0;
        }

        public static T[] CreateNeurons<T>(double[] biases)
            where T : Neuron, new() 
        {
            T[] result = new T[biases.Length];

            for (int i = 0; i < result.Length; i++)
            {
                result[i] = new T { Bias = biases[i] };
            }

            return result;
        }

        public static void WireConnections(SymmetricConnection[][] connections)
        {
            foreach (var connectionRow in connections)
            {
                foreach (var connection in connectionRow)
                {
                    Synapse hiddenConnection = new Synapse();
                    hiddenConnection.Connection = connection;
                    hiddenConnection.SourceNeuron = connection.VisibleNeuron;
                    connection.HiddenNeuron.Synapses.Add(hiddenConnection);

                    Synapse visibleConnection = new Synapse();
                    visibleConnection.Connection = connection;
                    visibleConnection.SourceNeuron = connection.HiddenNeuron;
                    connection.VisibleNeuron.Synapses.Add(visibleConnection);
                }
            }
        }

        protected double GetInputFromSourceNeurons()
        {
            return Synapses.Sum(s => s.Connection.Weight * s.SourceNeuron.Value) + Bias;
        }

        protected static double Sigmoid(double x)
        {
            return 1.0 / (1.0 + Math.Exp(-x));
        }
    }

    public class VisibleNeuron : Neuron
    {
        public override void Update()
        {
            this.Value = Sigmoid(GetInputFromSourceNeurons());
        }

        public override void Learn(double learningRate)
        {
            this.DeltaBias += learningRate * this.Value;
        }
    }

    public class HiddenNeuron : Neuron
    {
        public double Probability;
        public Random Random;

        public override void Learn(double learningRate)
        {
            this.DeltaBias += learningRate * this.Probability;
        }

        public override void Update()
        {
            this.Probability = Sigmoid(GetInputFromSourceNeurons());
            this.Value = nextBool(Random, this.Probability) ? 1 : 0;
        }

        private static bool nextBool(Random random, double rate)
        {
            if (rate < 0 || 1 < rate) return false;
            return random.NextDouble() < rate;
        }
    }
}

Synapse.cs

シナプスのクラスです。
変わったのは、隠れニューロンと可視ニューロンの型くらいです。

using System;

namespace RestrictedBoltzmannMachines.RealValue
{
    public class Synapse
    {
        public Neuron SourceNeuron;
        public SymmetricConnection Connection;
    }

    public class SymmetricConnection
    {
        public double Weight;
        public double DeltaWeight;
        public VisibleNeuron VisibleNeuron;
        public HiddenNeuron HiddenNeuron;

        public void Learn(double learningRate)
        {
            this.DeltaWeight += 
                learningRate * VisibleNeuron.Value * HiddenNeuron.Probability;
        }

        public void EndLearning()
        {
            this.Weight += this.DeltaWeight;
            this.DeltaWeight = 0;
        }

        public static double[][] CreateRandomWeights(Random random, int visibleNeuronCount, int hiddenNeuronCount)
        {
            var result = createJaggedArray<double>(visibleNeuronCount, hiddenNeuronCount);

            double a = 1.0 / visibleNeuronCount;

            for (int i = 0; i < visibleNeuronCount; i++)
            {
                for (int j = 0; j < hiddenNeuronCount; j++)
                {
                    result[i][j] = uniform(random, -a, a);
                }
            }

            return result;
        }

        private static T[][] createJaggedArray<T>(int visibleNeuronCount, int hiddenNeuronCount)
        {
            var result = new T[visibleNeuronCount][];

            for (int i = 0; i < visibleNeuronCount; i++)
            {
                result[i] = new T[hiddenNeuronCount];
            }

            return result;
        }

        private static double uniform(Random random, double min, double max)
        {
            return random.NextDouble() * (max - min) + min;
        }

        public static SymmetricConnection[][] CreateConnections(
            double[][] weights, 
            VisibleNeuron[] visibleNeurons,
            HiddenNeuron[] hiddenNeurons)
        {
            var result = createJaggedArray<SymmetricConnection>(visibleNeurons.Length, hiddenNeurons.Length);

            for (int i = 0; i < visibleNeurons.Length; i++)
            {
                for (int j = 0; j < hiddenNeurons.Length; j++)
                {
                    SymmetricConnection connection = new SymmetricConnection();
                    connection.Weight = weights[i][j];
                    connection.VisibleNeuron = visibleNeurons[i];
                    connection.HiddenNeuron = hiddenNeurons[j];
                    result[i][j] = connection;
                }
            }

            return result;
        }
    }
}

実行結果

制限付きボルツマンマシンは、学習したデータを、ヒントを与えられることによって思い出します。
実行結果は思い出したデータです。
実行結果を見る前に、学習データとヒントをおさらいしましょう。
学習データは次のようなものでした:

1.0  0.5  1.0  0.0  0.0
0.0  0.0  1.0  0.5  1.0
0.0  1.0  0.5  1.0  0.0

そして与えたヒントはこうです:

1    1    0    0    0
0    0    0    1    1

では実行結果を見てみましょう。
実行結果はこうなります:

0.98    0.52    0.97    0.02    0.00
0.01    0.08    0.93    0.62    0.94

どうでしょうか。
まあだいたい上手く想起しているように見えませんか?
学習データの1番目と2番目を想起していますね。

拍手[0回]

PR