[PR]
×
[PR]上記の広告は3ヶ月以上新規記事投稿のないブログに表示されています。新しい記事を書く事で広告が消えます。
プログラミング、3DCGとその他いろいろについて
[PR]上記の広告は3ヶ月以上新規記事投稿のないブログに表示されています。新しい記事を書く事で広告が消えます。
Deep Learningが十分に重い事がわかったので、パフォーマンスを追求してみます。
LinqのSumメソッドとList<T>、T[]の手続き的な合計を比較してみます。
using System;
using System.Collections.Generic;
using System.Linq;
using System.Diagnostics;

/// <summary>
/// One incoming edge of a neuron: the neuron whose <c>Value</c> is read and the
/// connection whose <c>Weight</c> scales it.
/// </summary>
public struct Synapse
{
    public Neuron SourceNeuron;
    public SymmetricConnection Connection;
}

/// <summary>
/// A weighted link between one visible neuron and one hidden neuron. The same
/// instance is shared by the synapses created for both directions.
/// </summary>
public class SymmetricConnection
{
    public double Weight;
    public double DeltaWeight;
    public Neuron VisibleNeuron;
    public Neuron HiddenNeuron;

    /// <summary>
    /// Accumulates <c>learningRate * VisibleNeuron.Value * HiddenNeuron.Probability</c>
    /// into <see cref="DeltaWeight"/>. <see cref="Weight"/> is not touched here.
    /// </summary>
    public void Learn(double learningRate)
    {
        this.DeltaWeight += learningRate * VisibleNeuron.Value * HiddenNeuron.Probability;
    }

    /// <summary>
    /// Adds the accumulated <see cref="DeltaWeight"/> to <see cref="Weight"/> and resets the accumulator to zero.
    /// </summary>
    public void EndLearning()
    {
        this.Weight += this.DeltaWeight;
        this.DeltaWeight = 0;
    }

    /// <summary>
    /// Builds a <c>[visibleNeuronCount][hiddenNeuronCount]</c> jagged array filled with
    /// values of the form <c>random.NextDouble() * 2a - a</c>, where <c>a = 1 / visibleNeuronCount</c>.
    /// </summary>
    public static double[][] CreateRandomWeights(Random random, int visibleNeuronCount, int hiddenNeuronCount)
    {
        var result = createJaggedArray<double>(visibleNeuronCount, hiddenNeuronCount);
        double a = 1.0 / visibleNeuronCount;
        for (int i = 0; i < visibleNeuronCount; i++)
        {
            for (int j = 0; j < hiddenNeuronCount; j++)
            {
                result[i][j] = uniform(random, -a, a);
            }
        }
        return result;
    }

    /// <summary>
    /// Allocates a jagged array with <paramref name="visibleNeuronCount"/> rows of
    /// <paramref name="hiddenNeuronCount"/> default-initialized elements each.
    /// </summary>
    private static T[][] createJaggedArray<T>(int visibleNeuronCount, int hiddenNeuronCount)
    {
        var result = new T[visibleNeuronCount][];
        for (int i = 0; i < visibleNeuronCount; i++)
        {
            result[i] = new T[hiddenNeuronCount];
        }
        return result;
    }

    /// <summary>
    /// Maps one <c>random.NextDouble()</c> sample linearly onto the interval starting at
    /// <paramref name="min"/> with width <c>max - min</c>.
    /// </summary>
    private static double uniform(Random random, double min, double max)
    {
        return random.NextDouble() * (max - min) + min;
    }

    /// <summary>
    /// Creates one connection per (visible, hidden) pair. <c>result[i][j]</c> links
    /// <c>visibleNeurons[i]</c> to <c>hiddenNeurons[j]</c> with weight <c>weights[i][j]</c>.
    /// </summary>
    public static SymmetricConnection[][] CreateConnections(double[][] weights, Neuron[] visibleNeurons, Neuron[] hiddenNeurons)
    {
        var result = createJaggedArray<SymmetricConnection>(visibleNeurons.Length, hiddenNeurons.Length);
        for (int i = 0; i < visibleNeurons.Length; i++)
        {
            for (int j = 0; j < hiddenNeurons.Length; j++)
            {
                SymmetricConnection connection = new SymmetricConnection();
                connection.Weight = weights[i][j];
                connection.VisibleNeuron = visibleNeurons[i];
                connection.HiddenNeuron = hiddenNeurons[j];
                result[i][j] = connection;
            }
        }
        return result;
    }
}

/// <summary>
/// A stochastic binary unit. Each <c>UpdateBy*</c> method computes the same weighted
/// input sum over the incoming synapses using a different iteration strategy, so the
/// strategies can be timed against each other.
/// </summary>
public class Neuron
{
    public double Value;
    public double Probability;
    public double Bias;
    public double DeltaBias;
    public List<Synapse> Synapses = new List<Synapse>();

    // Snapshot of Synapses taken by UpdateSynapseArray(); null until that method is called.
    private Synapse[] SynapseArray;

    public Random Random;

    /// <summary>
    /// Recomputes <see cref="Probability"/> and samples <see cref="Value"/>, summing the inputs with LINQ <c>Sum</c>.
    /// </summary>
    public void UpdateByLinq()
    {
        this.Probability = sigmoid(GetInputFromSourceNeuronsByLinq() + Bias);
        this.Value = nextBool(Random, this.Probability) ? 1 : 0;
    }

    private double GetInputFromSourceNeuronsByLinq()
    {
        return Synapses.Sum(s => s.Connection.Weight * s.SourceNeuron.Value);
    }

    /// <summary>
    /// Recomputes <see cref="Probability"/> and samples <see cref="Value"/>, summing the inputs
    /// with an index loop over the <see cref="Synapses"/> list.
    /// </summary>
    public void UpdateByList()
    {
        this.Probability = sigmoid(GetInputFromSourceNeuronsByList() + Bias);
        this.Value = nextBool(Random, this.Probability) ? 1 : 0;
    }

    private double GetInputFromSourceNeuronsByList()
    {
        double result = 0;
        // Bound by the list's own Count. The previous bound, SynapseArray.Length, made this
        // path throw when the snapshot had not been built and read the wrong number of
        // entries when the list had changed since the snapshot was taken.
        for (int i = 0; i < Synapses.Count; i++)
        {
            var s = Synapses[i];
            result += s.Connection.Weight * s.SourceNeuron.Value;
        }
        return result;
    }

    /// <summary>
    /// Recomputes <see cref="Probability"/> and samples <see cref="Value"/>, summing the inputs
    /// over the array snapshot. <see cref="UpdateSynapseArray"/> must have been called first;
    /// otherwise the snapshot is null.
    /// </summary>
    public void UpdateBySynapseArray()
    {
        this.Probability = sigmoid(GetInputFromSourceNeuronsBySynapseArray() + Bias);
        this.Value = nextBool(Random, this.Probability) ? 1 : 0;
    }

    private double GetInputFromSourceNeuronsBySynapseArray()
    {
        double result = 0;
        for (int i = 0; i < SynapseArray.Length; i++)
        {
            var s = SynapseArray[i];
            result += s.Connection.Weight * s.SourceNeuron.Value;
        }
        return result;
    }

    /// <summary>
    /// Creates one neuron per entry of <paramref name="biases"/>, with <see cref="Bias"/> set to that entry.
    /// </summary>
    public static Neuron[] CreateNeurons(double[] biases)
    {
        Neuron[] result = new Neuron[biases.Length];
        for (int i = 0; i < result.Length; i++)
        {
            result[i] = new Neuron { Bias = biases[i] };
        }
        return result;
    }

    /// <summary>
    /// For every connection, appends two synapses that share it: one on the hidden neuron
    /// reading from the visible neuron, and one on the visible neuron reading from the hidden neuron.
    /// </summary>
    public static void WireConnections(SymmetricConnection[][] connections)
    {
        foreach (var connectionRow in connections)
        {
            foreach (var connection in connectionRow)
            {
                Synapse hiddenConnection = new Synapse();
                hiddenConnection.Connection = connection;
                hiddenConnection.SourceNeuron = connection.VisibleNeuron;
                connection.HiddenNeuron.Synapses.Add(hiddenConnection);

                Synapse visibleConnection = new Synapse();
                visibleConnection.Connection = connection;
                visibleConnection.SourceNeuron = connection.HiddenNeuron;
                connection.VisibleNeuron.Synapses.Add(visibleConnection);
            }
        }
    }

    /// <summary>
    /// Copies the current contents of <see cref="Synapses"/> into the array snapshot used by
    /// <see cref="UpdateBySynapseArray"/>. Call again after changing <see cref="Synapses"/>.
    /// </summary>
    public void UpdateSynapseArray()
    {
        this.SynapseArray = Synapses.ToArray();
    }

    /// <summary>Logistic function <c>1 / (1 + e^-x)</c>.</summary>
    private static double sigmoid(double x)
    {
        return 1.0 / (1.0 + Math.Exp(-x));
    }

    /// <summary>
    /// Returns <c>false</c> when <paramref name="rate"/> is outside [0, 1] (no random sample is
    /// drawn in that case); otherwise returns whether one <c>NextDouble()</c> sample is below <paramref name="rate"/>.
    /// </summary>
    private static bool nextBool(Random random, double rate)
    {
        if (rate < 0 || 1 < rate)
        {
            return false;
        }
        return random.NextDouble() < rate;
    }
}

/// <summary>
/// Benchmark driver: wires 100 visible neurons to 1000 hidden neurons and times 1000
/// update passes over the visible neurons for each of the three summation strategies.
/// </summary>
class Program
{
    static void Main(string[] args)
    {
        var visibleNeurons = Neuron.CreateNeurons(Enumerable.Range(0, 100).Select(i => (double)i).ToArray());
        var hiddenNeurons = Neuron.CreateNeurons(Enumerable.Range(0, 1000).Select(i => (double)i).ToArray());
        var connections = SymmetricConnection.CreateConnections(
            SymmetricConnection.CreateRandomWeights(new Random(0), visibleNeurons.Length, hiddenNeurons.Length),
            visibleNeurons,
            hiddenNeurons
        );
        Neuron.WireConnections(connections);

        // Fixed seed: every neuron gets its own generator, seeded reproducibly.
        var random = new Random(0);
        foreach (var neuron in hiddenNeurons)
        {
            neuron.Random = new Random(random.Next());
            neuron.UpdateSynapseArray();
        }
        foreach (var neuron in visibleNeurons)
        {
            neuron.Random = new Random(random.Next());
            neuron.UpdateSynapseArray();
        }

        int loopCount = 1000;
        Stopwatch stopwatch = new Stopwatch();

        stopwatch.Start();
        for (int i = 0; i < loopCount; i++)
        {
            foreach (var neuron in visibleNeurons)
            {
                neuron.UpdateByLinq();
            }
        }
        stopwatch.Stop();
        Console.WriteLine("Linq : " + stopwatch.Elapsed.TotalMilliseconds);

        // Collect between measurements so garbage from one run is not collected during the next.
        GC.Collect();
        stopwatch.Restart();
        for (int i = 0; i < loopCount; i++)
        {
            foreach (var neuron in visibleNeurons)
            {
                neuron.UpdateByList();
            }
        }
        stopwatch.Stop();
        Console.WriteLine("List : " + stopwatch.Elapsed.TotalMilliseconds);

        GC.Collect();
        stopwatch.Restart();
        for (int i = 0; i < loopCount; i++)
        {
            foreach (var neuron in visibleNeurons)
            {
                neuron.UpdateBySynapseArray();
            }
        }
        stopwatch.Stop();
        Console.WriteLine("Array : " + stopwatch.Elapsed.TotalMilliseconds);
    }
}
結果はこちら:
Linq : 8770.8684
List : 1424.852
Array : 1097.6195
Linqは便利ですがやはり遅いですね。
私のDeep Learningのコードでは実はここがボトルネックになっています。
ここではLinqは使えませんね。
せっかくなので、ポインタも使ってみることにしました。
using System;
using System.Diagnostics;
using System.Linq;
using System.Collections.Generic;

/// <summary>
/// Benchmark driver: computes the sum of squares of the same sequence of floats with
/// LINQ <c>Sum</c>, a <c>List&lt;float&gt;</c> index loop, an array index loop and a
/// pinned-pointer loop, printing the elapsed milliseconds and the accumulated result for each.
/// </summary>
class Program
{
    static unsafe void Main(string[] args)
    {
        // Number of timed repetitions and number of elements are independent settings.
        // The element loops below are bounded by the collection sizes, not by loopCount.
        int loopCount = 10000;
        int elementCount = 10000;

        List<float> numberList = Enumerable.Range(0, elementCount).Select(i => (float)i).ToList();
        float[] numberArray = numberList.ToArray();
        Stopwatch stopwatch = new Stopwatch();

        // LINQ Sum with a selector lambda.
        float resultLinq = 0;
        stopwatch.Start();
        for (int i = 0; i < loopCount; i++)
        {
            resultLinq += numberList.Sum(number => number * number);
        }
        stopwatch.Stop();
        Console.WriteLine("Linq : " + stopwatch.Elapsed.TotalMilliseconds);
        Console.WriteLine("\t" + resultLinq);

        // Collect between measurements so garbage from one run is not collected during the next.
        GC.Collect();

        // Index loop over List<float>.
        float resultList = 0;
        stopwatch.Restart();
        for (int i = 0; i < loopCount; i++)
        {
            float result = 0;
            for (int index = 0; index < numberList.Count; index++)
            {
                float number = numberList[index];
                result += number * number;
            }
            resultList += result;
        }
        stopwatch.Stop();
        Console.WriteLine("List : " + stopwatch.Elapsed.TotalMilliseconds);
        Console.WriteLine("\t" + resultList);

        GC.Collect();

        // Index loop over float[].
        float resultArray = 0;
        stopwatch.Restart();
        for (int i = 0; i < loopCount; i++)
        {
            float result = 0;
            for (int index = 0; index < numberArray.Length; index++)
            {
                float number = numberArray[index];
                result += number * number;
            }
            resultArray += result;
        }
        stopwatch.Stop();
        Console.WriteLine("Array : " + stopwatch.Elapsed.TotalMilliseconds);
        Console.WriteLine("\t" + resultArray);

        GC.Collect();

        // Pointer loop over the pinned array; the array is pinned once per repetition.
        float resultPointer = 0;
        stopwatch.Restart();
        for (int i = 0; i < loopCount; i++)
        {
            float result = 0;
            fixed (float* numberPointer = numberArray)
            {
                // Pointer indexing is not range-checked, so the bound must be the array's real length.
                int length = numberArray.Length;
                for (int index = 0; index < length; index++)
                {
                    float number = numberPointer[index];
                    result += number * number;
                }
            }
            resultPointer += result;
        }
        // Stop before reading Elapsed, as the other three measurements do.
        stopwatch.Stop();
        Console.WriteLine("Pointer : " + stopwatch.Elapsed.TotalMilliseconds);
        Console.WriteLine("\t" + resultPointer);

        GC.Collect();
    }
}
実行結果はこうなります:
Linq : 5386.1902
	3.333105E+15
List : 1051.6524
	3.333105E+15
Array : 1047.8888
	3.333105E+15
Pointer : 1099.8356
	3.333105E+15
やはりLinqは重いです。
しかし思ったほどポインタは速くなりませんね。
きっとコンパイラの中で最適化されているのでしょう。