忍者ブログ

Memeplexes

プログラミング、3DCGとその他いろいろについて

GPUで疑似乱数 (DirectX 11)(Xorshift)

GPUで疑似乱数

人工知能の世界では今Deep Learningが流行りです。
そしてDeep Learningは重いのでGPUが向いています。
しかし制限(制約)付きボルツマンマシン(Restricted Boltzmann Machine:RBM)を使ったDeep Learningの場合、ネットワークの更新に乱数が必要です。
ニューロンが興奮するかどうかどうかを確率からランダムに決めるんですね。
ですから「GPUで乱数を得るには?」というのが今回のテーマです。
(CPU側で作った疑似乱数をGPUにどっと送ってもいいのですが、転送速度が気になります)

ここ数日OpenCLばっかり使っていたので今回はDirectX11を使ってあげましょう。
後でOpenCLでもやってみる気もしますけどね。


サンプルコード

Program.cs

using System.Linq;

using SlimDX;
using SlimDX.Direct3D11;
using SlimDX.D3DCompiler;

struct Xorshift128RandomGpu
{
    public int w, x, y, z;

	public Xorshift128RandomGpu(int seed)
    {
        if (seed == 0)
        {
            seed += 11;
        }

        w = seed;
        x = seed << 16 + seed >> 16;
        y = w + x;
        z = x ^ y;
    }
}

class Program
{
    const int ElementCount = 20;

    static Buffer ResultBuffer;
    static Buffer RandomGeneratorsBuffer;
    static ShaderBytecode updateRandomShaderBytecode;
    static ComputeShader updateRandomShader;
    static UnorderedAccessView[] unorderedAccessViews;

    static void Main(string[] args)
    {
        Device device = new Device(DriverType.Hardware);

        ResultBuffer = createStructuredBuffer(
            device,
            Enumerable.Range(0, ElementCount).ToArray(),
            BindFlags.UnorderedAccess
            );
        var random = new System.Random(0);
        RandomGeneratorsBuffer = createStructuredBuffer(
            device,
            Enumerable.Range(0, ElementCount).Select(i=> new Xorshift128RandomGpu(random.Next())).ToArray(),
            BindFlags.UnorderedAccess
            );

        updateRandomShaderBytecode = createShaderBytecode(
            device,
            System.IO.File.ReadAllText("MyShader.fx"),
            "UpdateRandom"
            );
        updateRandomShader = new ComputeShader(device, updateRandomShaderBytecode);
        device.ImmediateContext.ComputeShader.Set(updateRandomShader);
        unorderedAccessViews = new[] 
        { 
            new UnorderedAccessView(device, ResultBuffer),
            new UnorderedAccessView(device, RandomGeneratorsBuffer)
        };
        device.ImmediateContext.ComputeShader.SetUnorderedAccessViews(
            unorderedAccessViews, 0, unorderedAccessViews.Length
            );
        device.ImmediateContext.Dispatch(ElementCount, 1, 1);

        foreach (var number in readBuffer<float>(ResultBuffer))
        {
            System.Console.WriteLine(number);
        }

        unorderedAccessViews[0].Dispose();
        unorderedAccessViews[1].Dispose();
        updateRandomShader.Dispose();
        updateRandomShaderBytecode.Dispose();
        RandomGeneratorsBuffer.Dispose();
        ResultBuffer.Dispose();
        device.Dispose();
    }

    static ShaderBytecode createShaderBytecode(Device device, string source, string entryPoint)
    {
        return ShaderBytecode.Compile(
            source,
            entryPoint,
            "cs_5_0",
            ShaderFlags.None,
            EffectFlags.None
            );
    }

    static T[] readBuffer<T>(Buffer buffer) where T : struct
    {
        using (Buffer cpuAccessibleBuffer = createCpuAccessibleBuffer(buffer.Device, buffer.Description.SizeInBytes))
        {
            buffer.Device.ImmediateContext.CopyResource(buffer, cpuAccessibleBuffer);
            return readCpuAccessibleBuffer<T>(cpuAccessibleBuffer);
        }
    }

    static Buffer createStructuredBuffer<T>(Device device, T[] initialData, BindFlags bindFlags)
        where T : struct
    {
        using (DataStream initialDataStream = new DataStream(initialData, true, true))
        {
            return new Buffer(
                device,
                initialDataStream,
                new BufferDescription
                {
                    SizeInBytes = (int)initialDataStream.Length,
                    BindFlags = bindFlags,
                    OptionFlags = ResourceOptionFlags.StructuredBuffer,
                    StructureByteStride = System.Runtime.InteropServices.Marshal.SizeOf(typeof(T))
                }
                );
        }
    }

    static Buffer createCpuAccessibleBuffer(Device device, int sizeInBytes)
    {
        return new Buffer(
            device,
            new BufferDescription
            {
                SizeInBytes = sizeInBytes,
                CpuAccessFlags = CpuAccessFlags.Read,
                Usage = ResourceUsage.Staging
            }
            );
    }

    static T[] readCpuAccessibleBuffer<T>(Buffer from)where T : struct
    {
        DataBox data = from.Device.ImmediateContext.MapSubresource(
            from,
            0, 
            from.Description.SizeInBytes,
            MapMode.Read,
            MapFlags.None
            );
        return getArray<T>(data.Data);
    }

    static T[] getArray<T>(DataStream stream)where T : struct
    {
        T[] buffer = new T[stream.Length / System.Runtime.InteropServices.Marshal.SizeOf(typeof(T))];
        stream.ReadRange(buffer, 0, buffer.Length);
        return buffer;
    }
}

myShader.fx

class Xorshift128Random
{
	int w;
	int x;
	int y;
	int z;	
	static const uint INT_MAX = 1 << (32) - 1;

	static int Next(inout Xorshift128Random random)
	{
		int t = (random.x ^ (random.x << 11));
		random.x = random.y;
		random.y = random.z;
		random.z = random.w;
		random.w = (random.w = (random.w ^ (random.w >> 19)) ^ (t ^ (t >> 8)));
		return random.w;
	}

	static float NextFloat(inout Xorshift128Random random)
	{
		return ((float)Next(random) / INT_MAX);
	}
};

RWStructuredBuffer<float> ResultBuffer : register(u0);
RWStructuredBuffer<Xorshift128Random> RandomGenerators : register(u1);

[numthreads(1, 1, 1)]
void UpdateRandom(uint3 threadID : SV_DispatchThreadID)
{
	Xorshift128Random random = RandomGenerators[threadID.x];
	ResultBuffer[threadID.x] = Xorshift128Random::NextFloat(random);
	RandomGenerators[threadID.x] = random;
}

これはXorshiftというアルゴリズムを使って、GPUで疑似乱数を生成し、CPU側に戻すコードです。
ちなみに最初のSeedはCPU側から送ってあげる必要があります。
これを実行すると次のような結果になります:

0.7322494
0.8232476
0.7558671
0.5355725
0.2103059
0.5624865
0.8987527
0.4662975
0.998933
0.2610296
0.2951466
0.4376118
0.6250286
0.4811476
0.9961796
0.01847547
0.8486539
0.9736332
0.671631
0.3234635

なんとなくうまく動いているような気がします。

拍手[0回]

PR