백준 제곱수의 합

어떤 자연수 N은 그보다 작거나 같은 제곱수들의 합으로 나타낼 수 있다. 예를 들어 11=32+12+12(3개 항)이다. 이런 표현방법은 여러 가지가 될 수 있는데, 11의 경우 11=22+22+12+12+12(5개 항)도 가능하다. 이 경우, 수학자 숌크라테스는 “11은 3개 항의 제곱수 합으로 표현할 수 있다.”라고 말한다. 또한 11은 그보다 적은 항의 제곱수 합으로 표현할 수 없으므로, 11을 그 합으로써 표현할 수 있는 제곱수 항의 최소 개수는 3이다.

주어진 자연수 N을 이렇게 제곱수들의 합으로 표현할 때에 그 항의 최소개수를 구하는 프로그램을 작성하시오.

자연수 n을 i^2 의 합으로 표현할 때, 항의 최소 개수를 구하는 문제이다.

첫번째 풀이

n이 제곱이면 dp[n] = 2 n이 제곱이 아니라면 dp[n] = min((dp[n - 1] + dp[1]) .... dp[n - i] + dp[i]) (n - i => i)

첫번 째 풀이는 위와 같은 공식을 세워서 풀었다.

1부터 12까지를 예시로 살펴보자

  • 1 (1) = 1
  • 2 (1 + 1) = 1 + 1
  • 3 (2 + 1) = 1 + 1 + 1
  • 4 (2^2) = 2
  • 5 (4 + 1) = 2 + 1
  • 6 (4 + 2) = 2 + 1 + 1
  • 7 (4 + 3) = 2 + 1 + 1 + 1
  • 8 (4 + 4) = 2 + 2
  • 9 (3^2) = 3
  • 10 (9 + 1) = 3 + 1
  • 11 (9 + 2) = 3 + 1 + 1
  • 12 (9 + 3) = 3 + 1 + 1 + 1
    • 그러나, 12는 (8 + 4) = 2 + 2 + 2가 최소가 된다.
    • 이 부분 때문에 12를 조합할 수 있는 경우를 모두 고려했다.
    • dp[n] = min((dp[n - 1] + dp[1]) .... dp[n - i] + dp[i]) (n - i => i)
    • (11 + 1), (10 + 2), (9 + 3), (8 + 4) ...... (6 + 5)
@Test
public void test() {
    assertThat(solution(7))
            .isEqualTo(4);
    assertThat(solution(11))
            .isEqualTo(3);
    assertThat(solution(12))
            .isEqualTo(3);
    assertThat(solution(100_000))
            .isEqualTo(2);
    assertThat(solution(10_000))
            .isEqualTo(1);
    assertThat(solution(142))
            .isEqualTo(3);
}

private static int solution(int n) {
    final int MAX_SQUARE = 100_001;
    final int MAX_SQRT = (int) Math.sqrt(MAX_SQUARE);
    boolean[] isSquare = new boolean[MAX_SQUARE];
    int[] dp = new int[n + 1];

    isSquare[1] = true;
    dp[1] = 1;

    for (int i = 2; i <= n; i++) {
        if (isSquare[i]) {
            dp[i] = 1;
        } else {
            dp[i] = getMin(dp, i);
        }

        if (i <= MAX_SQRT) {
            isSquare[i * i] = true;
        }
    }

    return dp[n];
}

private static int getMin(int[] dp, int n) {
    int min = Integer.MAX_VALUE;
    for (int i = 1; n - i >= i; i++) {
        min = Math.min(min, dp[n - i] + dp[i]);
    }
    return min;
}

첫번째 반복문에서 O(n), getMin 함수에서 O(n / 2) 으로 꽤 오랜 시간이 걸리지만 통과를 하기는 한다. 어떻게 복잡도를 줄일 수 있을까?

두번째 풀이

다시 한번 예시를 살펴보자.

  • 1 (1) = 1
  • 2 (1 + 1) = 1 + 1
  • 3 (2 + 1) = 1 + 1 + 1
  • 4 (2 ^ 2) = 1
  • 5 (4 + 1) = 2 + 1
  • 6 (4 + 2) = 2 + 1 + 1
  • 7 (4 + 3) = 2 + 1 + 1 + 1
  • 8 (4 + 4) = 2 + 2
  • 9 (3 ^ 2) = 3
  • 10 (9 + 1) = 3 + 1
  • 11 (9 + 2) = 3 + 1 + 1
  • 12 (9 + 3) = 3 + 1 + 1 + 1

숫자 n에서 가장 근접한 제곱수를 m 이라고 가정한다. 각 숫자 n에서의 항의 개수 dp[n] 는 괄호에서 보이는 것과 같이 dp[m] + dp[n - m]이 된다.

그러나, 가장 근접한 제곱수를 선택하는 방법이 항상 정답은 아니다. 예외 케이스인 12의 경우 가장 근접한 제곱수를 선택하는 방법을 사용하면 4개가 나오지만, 사실 정답은 dp[8 + 4] = 3개이다.

12와 8의 관계는 어떻게 등장한걸까? 제곱수가 아닌 제곱근을 기준으로 살펴보자. 12의 제곱근은 3.xx 이다. 정수 3을 m이라고 가정해보자.

  • m = 1
    • m * m = 1
  • m = 2
    • m * m = 4
  • m = 3
    • m * m = 9

제곱근 m의 제곱수 (m * m)은 최소항의 개수가 항상 1이다. 그렇기 때문에 (m * m) 과 계산할 수 있는 식에서 최소항을 구할 수 있다. 12의 경우 1, 4, 9와 계산할 수 있는 식이 최소항을 가진다. 나머지 숫자에서 계산할 수 있는 식은 1, 4, 9의 경우의 항의 개수보다 크거나 같기 때문에 무시한다.

  • m = 1, 11 + 1 = 12
    • 4개
  • m = 2, 8 + 4 = 12
    • 3개
  • m = 3, 9 + 3 = 12
    • 4개
  • 무시하는 숫자
    • 10 + 2: 4개
    • 7 + 5: 6개
    • 6 + 6: 6개
    • .....
@Test
public void test() {
    assertThat(solution(7))
            .isEqualTo(4);
    assertThat(solution(11))
            .isEqualTo(3);
    assertThat(solution(12))
            .isEqualTo(3);
    assertThat(solution(100_000))
            .isEqualTo(2);
    assertThat(solution(10_000))
            .isEqualTo(1);
    assertThat(solution(142))
            .isEqualTo(3);
}

private static int solution(int n) {
    int [] dp = new int[n + 1];

    for(int i = 1; i <= n; i++) {
        dp[i] = i;
        for(int j = 1, pow = j * j; pow <= i; j++, pow = j * j)
            if(dp[i] > dp[i - pow] + 1) {
                dp[i] = dp[i - pow] + 1;
            }
    }
    return dp[n];
}

앞서 설명한 n, m, (m * m) 는 코드에서는 각각 i, j, pow 이다. dp[i - pow] + dp[pow]가 아닌 dp[i - pow] + 1와 같은 식을 사용한 이유는 어차피 pow는 제곱 수이며 dp[pow]는 1이기 때문이다.

두 번째 풀이에서는 첫 번째 풀이와 달리 i가 제곱 수인 경우를 고려할 필요가 없다. 4를 예로 들어보자면, i = 4, j = 2, i - pow = 0이 되며 결국 dp[0] + 1 = 0 + 1 = 1로 계산되기 때문이다.