忍者ブログ

Memeplexes

プログラミング、3DCGとその他いろいろについて

Linq vs List vs T[] vs ポインタ パフォーマンス対決

一番速いのはどれ?

Deep Learningが十分に重い事がわかったので、パフォーマンスを追求してみます。
LinqのSumメソッドとList<T>、T[]の手続き的な合計を比較してみます。

using System;
using System.Collections.Generic;
using System.Linq;
using System.Diagnostics;

/// <summary>
/// One incoming link of a neuron: pairs the neuron at the other end of the
/// link with the connection object that holds the link's weight.
/// </summary>
public struct Synapse
{
    // Neuron whose Value is multiplied by the connection weight when inputs are summed.
    public Neuron SourceNeuron;
    // Connection that supplies the Weight for this link.
    public SymmetricConnection Connection;
}

/// <summary>
/// A weighted link between one visible neuron and one hidden neuron.
/// Weight changes are accumulated in <see cref="DeltaWeight"/> and folded into
/// <see cref="Weight"/> only when <see cref="EndLearning"/> is called.
/// </summary>
public class SymmetricConnection
{
    public double Weight;
    public double DeltaWeight;
    public Neuron VisibleNeuron;
    public Neuron HiddenNeuron;

    /// <summary>
    /// Adds <paramref name="learningRate"/> * visible Value * hidden Probability
    /// to the pending weight change.
    /// </summary>
    /// <param name="learningRate">Scale factor applied to the update.</param>
    public void Learn(double learningRate)
    {
        // Multiply left to right so the rounding matches a single chained product.
        double scaledVisible = learningRate * VisibleNeuron.Value;
        DeltaWeight += scaledVisible * HiddenNeuron.Probability;
    }

    /// <summary>
    /// Applies the pending weight change to <see cref="Weight"/> and clears it.
    /// </summary>
    public void EndLearning()
    {
        Weight += DeltaWeight;
        DeltaWeight = 0;
    }

    /// <summary>
    /// Builds a visibleNeuronCount x hiddenNeuronCount jagged array of weights
    /// drawn uniformly from [-1/visibleNeuronCount, 1/visibleNeuronCount).
    /// </summary>
    /// <param name="random">Random source; consumed in row-major order.</param>
    /// <param name="visibleNeuronCount">Number of rows.</param>
    /// <param name="hiddenNeuronCount">Number of columns in every row.</param>
    /// <returns>The freshly allocated weight table.</returns>
    public static double[][] CreateRandomWeights(Random random, int visibleNeuronCount, int hiddenNeuronCount)
    {
        double[][] table = new double[visibleNeuronCount][];
        double bound = 1.0 / visibleNeuronCount;

        for (int row = 0; row < visibleNeuronCount; row++)
        {
            double[] cells = new double[hiddenNeuronCount];

            for (int column = 0; column < hiddenNeuronCount; column++)
            {
                cells[column] = sampleUniform(random, -bound, bound);
            }

            table[row] = cells;
        }

        return table;
    }

    /// <summary>
    /// Returns a value in [lower, upper) derived from one NextDouble() call.
    /// </summary>
    private static double sampleUniform(Random random, double lower, double upper)
    {
        double span = upper - lower;
        return random.NextDouble() * span + lower;
    }

    /// <summary>
    /// Creates one connection per (visible, hidden) neuron pair, taking the
    /// initial weight of pair (i, j) from <c>weights[i][j]</c>.
    /// </summary>
    /// <param name="weights">Initial weights indexed [visible][hidden].</param>
    /// <param name="visibleNeurons">Neurons that become VisibleNeuron of each row.</param>
    /// <param name="hiddenNeurons">Neurons that become HiddenNeuron of each column.</param>
    /// <returns>Connections indexed [visible][hidden].</returns>
    public static SymmetricConnection[][] CreateConnections(double[][] weights, Neuron[] visibleNeurons, Neuron[] hiddenNeurons)
    {
        int visibleCount = visibleNeurons.Length;
        int hiddenCount = hiddenNeurons.Length;
        SymmetricConnection[][] grid = new SymmetricConnection[visibleCount][];

        for (int v = 0; v < visibleCount; v++)
        {
            SymmetricConnection[] row = new SymmetricConnection[hiddenCount];

            for (int h = 0; h < hiddenCount; h++)
            {
                row[h] = new SymmetricConnection
                {
                    Weight = weights[v][h],
                    VisibleNeuron = visibleNeurons[v],
                    HiddenNeuron = hiddenNeurons[h]
                };
            }

            grid[v] = row;
        }

        return grid;
    }
}

/// <summary>
/// A unit whose <see cref="Value"/> is sampled as 0 or 1 from the sigmoid of
/// its weighted inputs plus <see cref="Bias"/>. The same update is offered in
/// three variants (LINQ, indexed List, indexed array) so their cost can be
/// compared.
/// </summary>
public class Neuron
{
    public double Value;
    public double Probability;
    public double Bias;
    public double DeltaBias;
    public List<Synapse> Synapses = new List<Synapse>();

    // Array snapshot of Synapses; null until UpdateSynapseArray() has been called.
    private Synapse[] SynapseArray;

    // Random source used by the Update* methods to sample Value; not assigned by this class.
    public Random Random;

    /// <summary>
    /// Recomputes Probability and Value, summing the inputs with LINQ's Sum.
    /// </summary>
    public void UpdateByLinq()
    {
        activate(GetInputFromSourceNeuronsByLinq());
    }

    private double GetInputFromSourceNeuronsByLinq()
    {
        return Synapses.Sum(s => s.Connection.Weight * s.SourceNeuron.Value);
    }

    /// <summary>
    /// Recomputes Probability and Value, summing the inputs with an index
    /// loop over the <see cref="Synapses"/> list.
    /// </summary>
    public void UpdateByList()
    {
        activate(GetInputFromSourceNeuronsByList());
    }

    private double GetInputFromSourceNeuronsByList()
    {
        double result = 0;

        // Bound the loop by the list being indexed. Using SynapseArray.Length here
        // would throw when the snapshot has not been taken yet and would walk the
        // wrong range whenever the list changed after the snapshot.
        for (int i = 0; i < Synapses.Count; i++)
        {
            var s = Synapses[i];
            result += s.Connection.Weight * s.SourceNeuron.Value;
        }

        return result;
    }

    /// <summary>
    /// Recomputes Probability and Value, summing the inputs with an index
    /// loop over the array snapshot taken by <see cref="UpdateSynapseArray"/>.
    /// </summary>
    public void UpdateBySynapseArray()
    {
        activate(GetInputFromSourceNeuronsBySynapseArray());
    }

    private double GetInputFromSourceNeuronsBySynapseArray()
    {
        double result = 0;

        for (int i = 0; i < SynapseArray.Length; i++)
        {
            var s = SynapseArray[i];
            result += s.Connection.Weight * s.SourceNeuron.Value;
        }

        return result;
    }

    /// <summary>
    /// Shared tail of every Update* variant: squashes input + Bias through the
    /// sigmoid, stores it as Probability, then samples Value as 1 or 0.
    /// </summary>
    private void activate(double input)
    {
        this.Probability = sigmoid(input + Bias);
        this.Value = nextBool(Random, this.Probability) ? 1 : 0;
    }

    /// <summary>
    /// Creates one neuron per entry of <paramref name="biases"/>, in order.
    /// </summary>
    /// <param name="biases">Bias assigned to the neuron at the same index.</param>
    /// <returns>Array of new neurons with the same length as the input.</returns>
    public static Neuron[] CreateNeurons(double[] biases)
    {
        Neuron[] result = new Neuron[biases.Length];

        for (int i = 0; i < result.Length; i++)
        {
            result[i] = new Neuron { Bias = biases[i] };
        }

        return result;
    }

    /// <summary>
    /// Registers every connection with both of its end points: the hidden
    /// neuron gets a synapse sourced from the visible neuron and vice versa.
    /// Both synapses share the same connection object.
    /// </summary>
    /// <param name="connections">Connections to wire, as a jagged array.</param>
    public static void WireConnections(SymmetricConnection[][] connections)
    {
        foreach (var connectionRow in connections)
        {
            foreach (var connection in connectionRow)
            {
                Synapse hiddenConnection = new Synapse();
                hiddenConnection.Connection = connection;
                hiddenConnection.SourceNeuron = connection.VisibleNeuron;
                connection.HiddenNeuron.Synapses.Add(hiddenConnection);

                Synapse visibleConnection = new Synapse();
                visibleConnection.Connection = connection;
                visibleConnection.SourceNeuron = connection.HiddenNeuron;
                connection.VisibleNeuron.Synapses.Add(visibleConnection);
            }
        }
    }

    /// <summary>
    /// Copies the current contents of <see cref="Synapses"/> into the array
    /// used by <see cref="UpdateBySynapseArray"/>. Call again after the list changes.
    /// </summary>
    public void UpdateSynapseArray()
    {
        this.SynapseArray = Synapses.ToArray();
    }

    private static double sigmoid(double x)
    {
        return 1.0 / (1.0 + Math.Exp(-x));
    }

    /// <summary>
    /// Returns true with probability <paramref name="rate"/>. A rate outside
    /// [0, 1] yields false without drawing a random number.
    /// </summary>
    private static bool nextBool(Random random, double rate)
    {
        if (rate < 0 || 1 < rate) return false;
        return random.NextDouble() < rate;
    }
}

/// <summary>
/// Benchmark driver: builds a 100 x 1000 fully connected network, then times
/// 1000 update passes over the visible neurons for each summation variant
/// (LINQ, List, array) and prints the elapsed milliseconds.
/// </summary>
class Program
{
    static void Main(string[] args)
    {
        double[] visibleBiases = Enumerable.Range(0, 100).Select(i => (double)i).ToArray();
        double[] hiddenBiases = Enumerable.Range(0, 1000).Select(i => (double)i).ToArray();
        Neuron[] visibleNeurons = Neuron.CreateNeurons(visibleBiases);
        Neuron[] hiddenNeurons = Neuron.CreateNeurons(hiddenBiases);

        double[][] weights = SymmetricConnection.CreateRandomWeights(
            new Random(0), visibleNeurons.Length, hiddenNeurons.Length);
        SymmetricConnection[][] connections =
            SymmetricConnection.CreateConnections(weights, visibleNeurons, hiddenNeurons);
        Neuron.WireConnections(connections);

        // Hidden neurons are seeded before visible ones, so every neuron
        // receives the same seed on every run.
        Random seedSource = new Random(0);

        foreach (Neuron neuron in hiddenNeurons.Concat(visibleNeurons))
        {
            neuron.Random = new Random(seedSource.Next());
            neuron.UpdateSynapseArray();
        }

        const int loopCount = 1000;

        // --- LINQ variant ---
        Stopwatch timer = Stopwatch.StartNew();

        for (int pass = 0; pass < loopCount; pass++)
        {
            for (int n = 0; n < visibleNeurons.Length; n++)
            {
                visibleNeurons[n].UpdateByLinq();
            }
        }

        timer.Stop();
        Console.WriteLine("Linq : " + timer.Elapsed.TotalMilliseconds);
        GC.Collect();

        // --- List variant ---
        timer.Restart();

        for (int pass = 0; pass < loopCount; pass++)
        {
            for (int n = 0; n < visibleNeurons.Length; n++)
            {
                visibleNeurons[n].UpdateByList();
            }
        }

        timer.Stop();
        Console.WriteLine("List : " + timer.Elapsed.TotalMilliseconds);
        GC.Collect();

        // --- Array variant ---
        timer.Restart();

        for (int pass = 0; pass < loopCount; pass++)
        {
            for (int n = 0; n < visibleNeurons.Length; n++)
            {
                visibleNeurons[n].UpdateBySynapseArray();
            }
        }

        timer.Stop();
        Console.WriteLine("Array : " + timer.Elapsed.TotalMilliseconds);
    }
}

結果はこちら:

Linq : 8770.8684
List : 1424.852
Array : 1097.6195

単なる合計

Linqは便利ですがやはり遅いですね。
私のDeep Learningのコードでは実はここがボトルネックになっています。
ここではLinqは使えませんね。

せっかくなので、ポインタも使ってみることにしました。

using System;
using System.Diagnostics;
using System.Linq;
using System.Collections.Generic;

/// <summary>
/// Benchmark driver: sums the squares of 10000 floats, repeated 10000 times,
/// using LINQ, an indexed List, an indexed array and a pinned pointer, and
/// prints the elapsed milliseconds and the accumulated result of each variant.
/// Requires compilation with unsafe code enabled.
/// </summary>
class Program
{
    static unsafe void Main(string[] args)
    {
        // Number of times each variant repeats the full sum.
        int loopCount = 10000;
        List<float> numberList = Enumerable.Range(0, 10000).Select(i => (float)i).ToList();
        float[] numberArray = numberList.ToArray();
        Stopwatch stopwatch = new Stopwatch();

        float resultLinq = 0;
        stopwatch.Start();

        for (int i = 0; i < loopCount; i++)
        {
            resultLinq += numberList.Sum(number => number * number);
        }

        stopwatch.Stop();

        Console.WriteLine("Linq : " + stopwatch.Elapsed.TotalMilliseconds);
        Console.WriteLine("\t" + resultLinq);
        GC.Collect();

        float resultList = 0;
        stopwatch.Restart();

        for (int i = 0; i < loopCount; i++)
        {
            float result = 0;

            // Iterate over the list's own size rather than the repetition count,
            // so the two values can be changed independently.
            for (int index = 0; index < numberList.Count; index++)
            {
                float number = numberList[index];
                result += number * number;
            }

            resultList += result;
        }

        stopwatch.Stop();

        Console.WriteLine("List : " + stopwatch.Elapsed.TotalMilliseconds);
        Console.WriteLine("\t" + resultList);
        GC.Collect();

        float resultArray = 0;
        stopwatch.Restart();

        for (int i = 0; i < loopCount; i++)
        {
            float result = 0;

            for (int index = 0; index < numberArray.Length; index++)
            {
                float number = numberArray[index];
                result += number * number;
            }

            resultArray += result;
        }

        stopwatch.Stop();

        Console.WriteLine("Array : " + stopwatch.Elapsed.TotalMilliseconds);
        Console.WriteLine("\t" + resultArray);
        GC.Collect();

        float resultPointer = 0;
        stopwatch.Restart();

        for (int i = 0; i < loopCount; i++)
        {
            float result = 0;

            // Pin the array so the raw pointer stays valid while it is read.
            fixed (float* numberPointer = numberArray)
            {
                for (int index = 0; index < numberArray.Length; index++)
                {
                    float number = numberPointer[index];
                    result += number * number;
                }
            }

            resultPointer += result;
        }

        // Stop before reporting, matching the other three measurements.
        stopwatch.Stop();

        Console.WriteLine("Pointer : " + stopwatch.Elapsed.TotalMilliseconds);
        Console.WriteLine("\t" + resultPointer);
        GC.Collect();
    }
}

実行結果はこうなります:

Linq : 5386.1902
        3.333105E+15
List : 1051.6524
        3.333105E+15
Array : 1047.8888
        3.333105E+15
Pointer : 1099.8356
        3.333105E+15

やはりLinqは重いです。
しかし思ったほどポインタは速くなりませんね。
きっとコンパイラの中で最適化されているのでしょう。

拍手[0回]

PR