<?xml version="1.0" encoding="utf-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom">
    <channel>
        <title>sang-yun.log</title>
        <link>https://velog.io/</link>
        <description>재수 고트</description>
        <lastBuildDate>Fri, 28 Nov 2025 12:10:05 GMT</lastBuildDate>
        <docs>https://validator.w3.org/feed/docs/rss2.html</docs>
        <generator>https://github.com/jpmonette/feed</generator>
        <image>
            <title>sang-yun.log</title>
            <url>https://velog.velcdn.com/images/sang-yun/profile/e85c362c-ff46-47df-973d-a42813d2a627/image.jpg</url>
            <link>https://velog.io/</link>
        </image>
        <copyright>Copyright (C) 2019. sang-yun.log. All rights reserved.</copyright>
        <atom:link href="https://v2.velog.io/rss/sang-yun" rel="self" type="application/rss+xml"/>
        <item>
            <title><![CDATA[함수형으로 MCTS를 만들어보자]]></title>
            <link>https://velog.io/@sang-yun/%ED%95%A8%EC%88%98%ED%98%95%EC%9C%BC%EB%A1%9C-MCTS%EB%A5%BC-%EB%A7%8C%EB%93%A4%EC%96%B4%EB%B3%B4%EC%9E%90</link>
            <guid>https://velog.io/@sang-yun/%ED%95%A8%EC%88%98%ED%98%95%EC%9C%BC%EB%A1%9C-MCTS%EB%A5%BC-%EB%A7%8C%EB%93%A4%EC%96%B4%EB%B3%B4%EC%9E%90</guid>
            <pubDate>Fri, 28 Nov 2025 12:10:05 GMT</pubDate>
            <description><![CDATA[<h2 id="서두">서두</h2>
<p>스칼라로 바둑 인공지능을 만들었다.</p>
<p>바둑 인공지능을 만들때는 대부분 MCTS를 사용한다.
MCTS는 몬테카를로 트리 서치의 약자인데, 트리를 생성하며 몬테카를로법을 효율적으로 적용하는 알고리즘이라고 보면 된다. select, expand, simulate, backpropagate의 네 단계로 구성된다. <a href="https://ko.wikipedia.org/wiki/%EB%AA%AC%ED%85%8C%EC%B9%B4%EB%A5%BC%EB%A1%9C_%ED%8A%B8%EB%A6%AC_%ED%83%90%EC%83%89">자세한 설명은 위키백과를 참고하자.</a></p>
<p>이번에는 MCTS를 함수형으로 구현해볼 것이다. (재미로)
<img src="https://velog.velcdn.com/images/sang-yun/post/37ea8eca-9a53-4626-a3d6-47ece520a27b/image.svg" alt="유명한 MCTS 도식"></p>
<h3 id="참고-바둑에-mcts를-적용할-때의-주의점">(참고) 바둑에 MCTS를 적용할 때의 주의점</h3>
<p><em>게임이 끝나는 조건을 구현하는게 빡세다. 왜냐하면 바둑은 규칙 상 두 선수가 동의하여 게임을 끝내는 것을 전제하기 때문이다. 그래서 휴리스틱을 적용해야 한다. 나는 돌의 세기를 집과 함께 계산하여 특정 시점에 (ex 200수 경과) 점수가 높은 선수가 승리하는 것으로 했다. 이 부분을 만들 때에는 바둑 지식이 꼭 필요하다.</em></p>
<h2 id="절차형으로-만들면">절차형으로 만들면</h2>
<p>절차형으로 만들고 싶다면 알고리즘의 네 단계를 그대로 코드에 옮기면 된다. 아무 생각 없이 만들면 된다.</p>
<h2 id="함수형으로-만들면">함수형으로 만들면</h2>
<p>세 가지 함수를 정의한다. 간단하게 표현하면 다음과 같다.
<code>select(state, node) = (updated_node, winner)</code>
<code>expand(state, node) = (updated_node, winner)</code>
<code>simulate(state) = winner</code>
현 상태와 노드를 받고 각 단계를 실행한 후의 새로 업데이트된 노드를 반환하는 것이다. 모두 순수함수이며 외부 상태에 의존하지 않는다.
그런데 <code>backpropagate()</code> 함수가 없다. 잘 생각해 보면 당연한데, backpropagate는 노드의 상태를 업데이트하는 단계이다. 즉, 함수형으로 작성하면 위의 세 가지 함수가 새로운 트리를 반환하는 것으로 대체한다.</p>
<h2 id="코드-더러움">코드 (더러움)</h2>
<pre><code class="language-scala">object MCTS {
  private val c = 1.5f
  case class MCTSNode(move: Move, turn: Int, children: List[MCTSNode], moves: List[Move], n: Int = 0, q: Float = 0) {
    val fullyExpanded: Boolean = moves.isEmpty

    def eval(p: Int): Float = q / n + c * math.sqrt(math.log(p).toFloat / n).toFloat

    override def toString: String = f&quot;$move: $q / $n&quot;
  }

  private def getBestNMoves(game: Baduk, n: Int, color: Int): List[Move] = {
    val size = game.get_board.size
    val moves =
      if (game.movesCount &gt; 5) game.get_moves()
      else if (game.movesCount &gt; 3) game.get_moves().filterNot(m=&gt; m.x == 0 || m.x == size - 1 || m.y == 0 || m.y == size - 1)
      else game.get_moves().filterNot(m=&gt; m.x &lt;= 1 || m.x &gt;= size - 2 || m.y &lt;= 1 || m.y &gt;= size - 2)

    val scoresPair = moves.map(m =&gt; (Eval(game.go(m)).valueBy(color), m))
    scoresPair.sortBy(tup =&gt; -tup._1).slice(0, Math.min(n, moves.length-1)).map(tup =&gt; tup._2)
  }

  private def filterMoves(moves: List[Move], game: Baduk): List[Move] = {
    val size = game.get_board.size
    val moves =
      if (game.movesCount &gt; 5) game.get_moves()
      else if (game.movesCount &gt; 3) game.get_moves().filterNot(m =&gt; m.x == 0 || m.x == size - 1 || m.y == 0 || m.y == size - 1)
      else game.get_moves().filterNot(m =&gt; m.x &lt;= 1 || m.x &gt;= size - 2 || m.y &lt;= 1 || m.y &gt;= size - 2)

    moves
  }

  @tailrec
  private def simulate(state: Baduk, limit: Int): Int = {
    if (state.end() || limit &lt;= 0)
      Eval(state).getWinner
    else {
      val moves = filterMoves(state.get_moves(), state)
      val next_state = state.go(moves(Random.nextInt(moves.length)))
      simulate(next_state, limit - 1)
    }
  }

  private def expand(state: Baduk, node: MCTSNode): (MCTSNode, Int, List[Move]) = {
    val move = node.moves(Random.nextInt(node.moves.length))
    val next_state = state.go(move)
    val winner = simulate(next_state, 100)
    val policyMoves = Policy(next_state).moves
    (MCTSNode(move, state.get_turn, List(), filterMoves(policyMoves.toList, next_state)/*getBestNMoves(next_state, 10, next_state.get_turn)*/, n = 1,
      q = if (winner == state.get_turn) 1 else if (winner == 0) 0.5f else 0), winner,
      node.moves.filterNot(elm =&gt; elm == move))
  }

  def select(state: Baduk, node: MCTSNode): (MCTSNode, Int) = {
    val delta_q = (winner: Int) =&gt; if (winner == node.turn) 1 else if (winner==0) 0.5f else 0
    if (state.end()) {
      val winner = state.winner()._1
      val new_node = node.copy(n=node.n+1, q=node.q + delta_q(winner))
      (new_node, winner)
    }
    else if (!node.fullyExpanded) {
      val (child, winner, moves) = expand(state, node)
      val new_node = node.copy(children= child::node.children, moves=moves, n=node.n+1, q=node.q + delta_q(winner))
      (new_node, winner)
    }
    else {
      val (max_child, max_index) = node.children.zipWithIndex.maxBy(_._1.eval(node.n))
      val (child, winner) = select(state.go(max_child.move), max_child)
      val new_node = node.copy(children = node.children.updated(max_index, child), n = node.n + 1, q= node.q + delta_q(winner))
      (new_node, winner)
    }
  }

  @tailrec
  private def searchNTimes(root_state: Baduk, node: MCTSNode, n: Int, visitSkip: Boolean=true, totalVisits: Int): MCTSNode = {
    //if (n % 1000 == 0)
    //  print(&#39;.&#39;)
    if (n == 0) node
    else
    {
      if (visitSkip &amp;&amp; node.children.length &gt;= 2 &amp;&amp; n % 100 == 0) {
        val max = node.children.maxBy(c =&gt; c.n)
        val second = node.children.filterNot(c =&gt; c == max).maxBy(c =&gt; c.n)
        val distance = max.n - second.n
        if (distance &gt; totalVisits){
          node
        }
        else {
          searchNTimes(root_state, select(root_state, node)._1, n - 1, totalVisits = totalVisits - 1)

        }
      } else{
        searchNTimes(root_state, select(root_state, node)._1, n - 1, totalVisits = totalVisits - 1)

      }
    }
  }
  def getContinuesMCTSRootNodeAndMaxFunction(rootTemp: MCTSNode, game: Baduk, visits: Int, totalVisits: Int): (MCTSNode, Move) = {
    val root =
      if (rootTemp == null)
        MCTSNode(game.null_move(), game.get_turn * -1, List(), Policy(game).moves.toList)
      else
        rootTemp
    val searchedRoot = searchNTimes(game, root, visits, totalVisits = totalVisits)
    //for(c &lt;- searchedRoot.children) println(s&quot;Move ${c.move}: ${c.q} / ${c.n}&quot;)
    val best_child = searchedRoot.children.maxBy(_.n)
    //println(s&quot;Move ${best_child.move}: ${best_child.q} / ${best_child.n} (${(best_child.q / best_child.n * 100).round}%)&quot;)
    (searchedRoot, best_child.move)
  }
}
</code></pre>
<h2 id="그래서-잘-둠">그래서 잘 둠?</h2>
<p>9줄 바둑에서 실행해보면 2~3 만 회 서치했을 때 바둑 좀 배운 초딩 수준으로 둔다. 그래도 별 다른 휴리스틱이나 인공신경망 없이 이 정도만 되어도 대단하다고 생각한다.</p>
<h3 id="참고-uct-함수의-이상한-점">(참고) UCT 함수의 이상한 점</h3>
<pre><code>private val c = 1.5f
def eval(p: Int): Float = q / n + c * math.sqrt(math.log(p).toFloat / n).toFloat
</code></pre><p>여기서 c값을 1.5로 해놨다. 그런데 보통 c 값은 루트2 아닌가?
사실은 1~3 사이 아무 값이나 상관 없다고 한다. 루트2를 쓰게 된 배경은 논문에서 증명하기 편한 숫자 아무거나 쓴 거라고 한다. 그러니 루트2로 할 이유는 전혀 없고 몇번 테스트해보아서 좋은 숫자로 하면 된다.</p>
]]></description>
        </item>
    </channel>
</rss>