忍者ブログ

Memeplexes

プログラミング、3DCGとその他いろいろについて

Actor-Criticアルゴリズムのサンプルコード(dart)

Actor-Criticコードを置いときます。


コード

TileCoding.dart

class TileCoding{
  List<double> _tile;

  TileCoding(int tileResolution){
    _tile = new List<double>.filled(tileResolution, 0.0);
  }
  
  void learn(double input, double expectedResult, double learningRate){
    int index = getIndex(input);
    _tile[index] += learningRate * (expectedResult - _tile[index]);
  }
  
  double getValue(double input){
    if(input >= 1.0){ return _tile.last; }
    
    int index = getIndex(input);
    return _tile[index];
  }

  int getIndex(double input) {
    int index = (input * _tile.length).toInt();
    return index;
  }
}

Brain.dart

import "dart:math";
import "TileCoding.dart";

enum Action{
  LEFT,
  RIGHT
}

class Brain{
  static final int tileResolution = 10;
  
  double consideringFuture = 0.9;
  Random random = new Random();
  
  TileCoding _leftAction = new TileCoding(tileResolution);
  TileCoding _rightAction = new TileCoding(tileResolution);
  TileCoding _valuePrediction = new TileCoding(tileResolution);

  double _previousEnvironment;
  Action _action;
  
  Action get action => _action;
  
  void makeDecision(double environment){
    double leftValue = _leftAction.getValue(environment);
    double rightValue = _rightAction.getValue(environment);
    
    double leftExp = exp(leftValue);
    double rightExp = exp(rightValue);
    
    double leftProbability = leftExp / (leftExp + rightExp);

    bool isLeft = _nextBool(leftProbability);
    this._action = isLeft ? Action.LEFT : Action.RIGHT;
    this._previousEnvironment = environment;
  }
  
  bool _nextBool(double probability){
    return random.nextDouble() < probability;
  }
  
  void reinforce(
                 double newEnvironment,
                 double reward,
                 double learningRate
                 ){
    double previousValue = _valuePrediction.getValue(_previousEnvironment);
    double currentValue = _valuePrediction.getValue(newEnvironment);
    
    double tdError = reward 
        + consideringFuture * currentValue 
        - previousValue;
    
    _valuePrediction.learn(
        _previousEnvironment,
        previousValue + tdError,
        learningRate
        );
    
    switch(_action){
      case Action.LEFT:
        _leftAction.learn(
            _previousEnvironment,
            _leftAction.getValue(_previousEnvironment) + tdError,
            learningRate
            );
        break;
      case Action.RIGHT:
        _rightAction.learn(
            _previousEnvironment,
            _rightAction.getValue(_previousEnvironment) + tdError,
            learningRate
            );
        break;
    }
  }
}

Robot.dart

import "Brain.dart";

class Robot{
  static final double learningRate = 0.9;
  static final double leftWall = 0.0;
  static final double rightWall = 1.0;
  
  Brain brain = new Brain();
  
  double position = 0.5;
  double velocity = 0.0;
  
  void update(double mousePosition){
    double environment = _getEnvironment(mousePosition);

    brain.makeDecision(environment);
    
    _updatePhysics(brain.action);

    double reward = _computeReward(mousePosition);
    brain.reinforce(
        _getEnvironment(mousePosition),
        reward,
        learningRate
        );
  }

  void _updatePhysics(Action action) {
    double acceleration = _computeAcceleration(action);
    velocity += acceleration;
    position += velocity;
    
    // walls
    if(position < leftWall){
      position = leftWall;
      velocity = 0.0;
    }
    if(position > rightWall){
      position = rightWall;
      velocity = 0.0;
    }
  }

  double _computeAcceleration(Action action) {
    double acceleration = 0.0;
    
    switch(action){
      case Action.LEFT:
        acceleration = -0.01;
        break;
      case Action.RIGHT:
        acceleration = 0.01;
        break;
    }
    
    const double airDrag = 0.3;
    acceleration += -velocity * airDrag;
    return acceleration;
  }

  double _computeReward(double mousePosition) {
    double reward = 1 - (mousePosition - position).abs();
    if(reward < 0.8){ reward = 0.0; }
    return reward;
  }

  double _getEnvironment(double mousePosition) {
    double subjectiveMouse = mousePosition - position;
    double environment = (subjectiveMouse + 1) / 2;
    return environment;
  }
}

index.dart

import "Robot.dart";

import "dart:html";
import "package:stagexl/stagexl.dart";

Robot robot = new Robot();
double mousePosition = 0.0;

CanvasElement canvas = querySelector("#canvas");
Sprite robotSprite = new Sprite();
Sprite mouseSprite = new Sprite();

void main(){
  canvas.style.border = "solid";
  
  RenderLoop renderLoop = new RenderLoop();
  Stage stage = new Stage(canvas);
  Sprite sprite = new Sprite();
  sprite.scaleX = canvas.width;
  sprite.scaleY = canvas.height;
  sprite.addChild(robotSprite);
  sprite.addChild(mouseSprite);
  stage.addChild(sprite);
  
  sprite.onMouseMove.listen((e){
    mousePosition = e.localX;
  });
  
  renderLoop.addStage(stage);
  renderLoop.start();
  
  animate(0.0);
}

void animate(double deltaTime){
  update();
  draw();
  window.animationFrame.then((dt)=> animate(dt));
}

void update(){
  robot.update(mousePosition);
}

void draw(){
  drawRobot();
  drawMouseMarker();
}

void drawMouseMarker() {
  mouseSprite.graphics.clear();
  mouseSprite.graphics.beginPath();
  mouseSprite.graphics.moveTo(mousePosition, 0);
  mouseSprite.graphics.lineTo(mousePosition, 1.0);
  mouseSprite.graphics.closePath();
  mouseSprite.graphics.strokeColor(Color.Black, 0.01);
}

void drawRobot() {
  robotSprite.graphics.clear();
  
  robotSprite.graphics.beginPath();
  robotSprite.graphics.rect(0, 0, 1.0, 1.0);
  robotSprite.graphics.closePath();
  robotSprite.graphics.fillColor(Color.White);
  
  robotSprite.graphics.beginPath();
  robotSprite.graphics.circle(robot.position, 0.5, 0.05);
  robotSprite.graphics.closePath();
  robotSprite.graphics.fillColor(Color.Orange);
}

index.html

<!DOCTYPE html>

<html>
  <head>
  	<meta charset="utf-8">
  	<meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>index</title>
  </head>
 
  <body>   
    <canvas id="canvas" width="400" height="400"></canvas>

    <script type="application/dart" src="index.dart"></script>
    <!-- for this next line to work, your pubspec.yaml file must have a dependency on 'browser' -->
    <script src="packages/browser/dart.js"></script>
  </body>
</html>

解説

これはおそらくもっとも簡単な形態のActor-Criticのサンプルコードだと思います。
Actor-Criticアルゴリズムは人の脳に似ています。

脳との対応

Brain.dartクラスを見てください。

  • _leftAction, _rightAction : 大脳皮質を表しています。具体的には運動野です。ロボットへの入力データを元に、どんな行動をするかを決定します。決断をするオブジェクトです。
  • _valuePrediction : 大脳基底核を表しています。ある状態がどのくらい「よい」ものなのかを記録します。より「よい」状態をもたらす行動を強化します。

ロボットによる選択

ロボットは基本ランダムな行動をします。
しかし、その確率はいくらでも変わるのです。
最初は50%50%の確率でも、最終的には99%と1%の確率になったりします。

このロボットの取れる選択肢は2つだけです。
左に行くか、右に行くか。
ちなみに一度に両方を選択することは出来ません。
必ず左か右か一つを選ばねばなりません。

もちろんすでに述べたように、左に行くか右に行くかは確率的です。
しかし、全くランダムというわけではありません。
その確率は_leftActionと_rightActionの出す数値によって変わるのです。
大きいほうが選ばれやすいです。

選択の責任

ですから、このロボットの自由意志は_leftActionと_rightActionの出す数値によって決まっていると言えます。
_leftActionが大きな数字を出せば、_leftActionが選ばれやすくなり、_rightActionが大きな数字を出せば_rightActionが選ばれやすくなります。
決定論的な数字の計算プロセスが自由意志なのです。
ロボットが右を選んだことによって誰かが死んだとしても、ロボットは刑事罰を受けるべきでしょう!

なるほど、_leftActionが0.6を出し、_rightActionが0.4を出したのに、乱数の要素で_rightActionが選ばれてしまった、という可能性はあります。
しかし、_rightActionの出す数字を0.0000000000000000001にすることも、ロボットには可能だったのです。
_rightActionの数字が0.4なんて不注意というものです。
前方不注意で誰かをひき殺す確率は0.4より小さかもしれませんが、だからといってひき殺してしまった時、前方不注意であったことが見逃されるわけではありません。

学習

もし何か行動をした直後の状態がおもったより良かったら、その行動を強化します。
_valuePrediction(大脳基底核)から信号が来て、直前に選ばれた_leftActionか_rightAction(大脳皮質)を強化します。
_valuePredictionはある状態の「良さ」の予測です。
ロボットがその状態をどのくらい良いと思っているのかを記録しています。

学習で重要なのは、「良いか悪いか」ではなく、「思ったより良かったか悪かったか」です。
もしあなたがテストで99点をとっても、100点取れると予測していたのなら、それは思ったより悪い結果なので、大脳基底核が大脳皮質を罰します。
もちろん1点であっても、0点だと予測していたのなら、大脳基底核が大脳皮質を褒めます。

「良いか悪いか」などどうでもいいのです。
思ったより良かったか悪かったかが大切なのです。
うつ病になりそうですが、これこそ生物が何億年も上手くやってきた方法なのです。

関連

価値マーカーを付けたデモはこちら。
ブラウザ上で実行できます。

拍手[1回]

PR