코드 저장소

공부에는 끝이 없다!

JAVA/코딩 테스트

(JAVA) 백준 문제 풀이 - 분할 정복 단계 - 1629번 곱셈 (실버 1)

VarcharC2K 2023. 12. 14. 16:40

이번 문제는 A,B,C 세가지 수가 주어지면 A를 B번 곱한 수를 C로 나눈 나머지를 출력하는 문제이다. 입력의 최대값은 2,147,483,647로 큰 값을 Long의 범위 안에서 어떻게 처리 할지가 핵심이었다. 그럼 어떻게 풀지 살펴보자.


수학적 지식이 필요하다

이번 문제를 풀기 위해선 수학적 지식이 조금 필요한데, 첫째는 지수 법칙이고 두번째는 모듈러 성질이다.

 

1. 지수 법칙 

2. 모듈러 성질 

이 2가지를 이용하여서 문제를 해결한다.

그럼 위의 공식과 분할 정복이 문제와 무슨 연관이 있느냐?

위 문제를 그냥 A*B %C로 푸는 경우 입력값이 최대인 2,147,483,647인 경우 Long의 범위를 넘어가게 된다.

따라서 지수를 보다 작은 값으로 나눠 줄 필요가 있는데, 이를 위하여 분할 정복과 지수 법칙이 사용된다.


분할 정복을 적용해 보자

지수 법칙에 따르면 지수를 반으로 나누어 곱한 값과 전체 값은 동일하다.

위의 이미지 처럼 지수가 1이 될때 까지 반으로 나누고 전체를 곱한 값이 원래의 값과 동일한 것이다.

따라서 우리는 지수를 반으로 나누어 1이 될때 까지 재귀시켜 값을 구함으로써 분할 정복을 수행할 수 있다.

코드로 살펴보자.

//a = 밑, exp = 지수
public static long cal(long a, long exp) {
        if (exp == 1) {
            return a % c;
        }

        long temp = cal(a, exp / 2);
  
  //지수가 홀수인 경우
     if (exp % 2 == 1) {
        return (temp * temp) * a;
    }
    
    //지수가 짝수인 경우
    return temp*temp;
}

 

코드를 보면 밑과 지수를 받아서 지수가 1이면 a를 c로 나누어 나머지를 반환하고 1이 아니면 반으로 나누어 재귀한다.

아래 쪽을 보면 홀수인 경우가 나와있는데 짝수인 경우 그냥 temp를 곱해주면 되지만 홀수인 경우 밑을 한번더 곱하여줘야 하기 때문이다.

여기서 문제가 발생하는데, 우리가 최종적으로 얻어야 하는 값은 C로 나눈 나머지 값이다.

지수가 홀수인 경우 temp * temp * a로 하게 되는데 입력값이 최대인 경우 이때 long의 범위를 넘어가 계산이 틀리게 된다.

따라서 모듈러 합동 공식을 사용하게 되는데 약간의 변형이 필요하다.

(temp * temp * A) % C = ((temp * temp % C) * (A % C)) % C
					  = (((temp * temp % C) % C) * (A % C)) % C
					  = ((temp * temp % C) * A) % C

따라서 위의 공식을 적용하면 최종 코드는 다음과 같다.


import java.io.*;
import java.util.*;

class Main{
    public static long a,b,c;
    public static String[] str;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        str = br.readLine().split(" ");
        a = Long.parseLong(str[0]);
        b = Long.parseLong(str[1]);
        c = Long.parseLong(str[2]);

        System.out.println(cal(a, b));
    }

    public static long cal(long a, long exp) {
        if (exp == 1) {
            return a % c;
        }

        long temp = cal(a, exp / 2);

        if (exp % 2 == 1) {
            return (temp * temp % c) * a % c;
        }

        return temp * temp % c;
    }
}

 

코드를 짜는거 자체는 어렵지 않았지만 수학적인 지식이 조금 필요하다 보니 애를 먹었던 문제 같다. 속도보다는 입력된 수의 범위를 좁히는 것에 초점이 맞춰져 있다 보니 새로웠던 문제였다.