Skip to content

페르마 소정리

소수 m, 임의의 수 a에 대해 아래 식이 성립한다.

a^(m-1) = 1 (mod m)

이 식을 활용해 역원을 구할 수 있다.

a^(m-1) = a*a^(m-2) = 1 (mod m)
즉, a^(m-2)가 a의 역원

증명

  1. a와 서로소인 소수 p에 대해 a, 2a, 3a, …, (p−1)a인 p−1개의 수를 p로 나눴을 때 나오는 나머지는 모두 다르다. 귀류법으로 증명된다.

    0<i<j<p인 정수에서 ia와 ja의 나머지가 같다고 하면

    ia != ja (mod p)
    ja - ia = 0 (mod p)
    • 위 식이 성립하기 위해선, (j-i)a가 p로 나눠떨어져야 한다.
    • a는 p와 서로소이므로 j-i가 p로 나눠떨어져야 한다.
    • 하지만 조건에 따라 i-j가 p보다 클 수 없다 (i-j < p-1). 성립하려면 i-j가 0, 즉 두 수가 같아야한다.
    • 모순 발생하므로 어떤 수와 0보다 크고 p보다 작은 수의 곱셈에 대해 p로 나눈 나머지는 항상 다르다.
  2. 아래와 같은 집합 A, B에서 집합 B는 p와 서로소인 수를 p로 나눌 때 생기는 모든 나머지들의 집합이다.

    A = { x | x=ia, i∈B }
    B = { 1, 2, ..., p−1 }
  3. a * 2a * 3a * ... * (p−1)a ≡ 1 * 2 * ... * (p−1) (mod p)
    (p−1)! * a^(p-1) ≡ (p-1)! (mod p)

    이므로, 양변을 (p−1)!로 나누면

    a^(m-1) = 1

    (m이 소수일 때 a와의 gcd가 1이라 곱셈 역원이 항상 존재하므로 합동의 양 변을 나눌 수 있다.)

알고리즘 분류: 페르마의 소정리

여담: 알고리즘 문제에서 결과를 10^9+7로 나눈 나머지로 출력시키는 이유

백준 등 프로그래밍 문제에서 결과를 10^9+7로 나눠 출력하라 지시하는 경우가 많다. 이 이유는 우선 쉽게 예상되듯 32비트 정수 오버플로우를 막기 위함이고, 하필 소수로 사용하는 이유는 소수여야 모듈러 곱셈 역원을 구할 수 있기 때문이다. 10^7+9는 10자리인 첫 번째 소수이기 때문에 자주 사용된다고 한다.

풀어 말하자면 실수에서 2를 나누는 것과 1/2를 곱하는 것 대신 군 안에서는 2를 나누는 대신 모듈러 했을 때의 곱셈 역원인 정수를 대신 곱해줄 수 있는데, 모듈러하는 값이 소수일 때만 곱셈 역원이 항상 존재한다. ( e.g. m=8, a=1, b=2에서 1/2 mod{8}을 구할 수 없음 )

팩토리얼, 조합 경우의 수와 같이 매우 큰 수를 계산할 때 곱셈 역원을 사용할 수 있다. 관련된 문제로 BOJ 16134 조합 (Combination)을 풀어볼 수 있다.

C(n, k) = n! x (k!)^-1 x ((n-k)!)^-1 (mod p)

페르마 소정리를 사용해

MOD = 1000000007
def power_mod(base, i):
ret = 1
base = base % MOD
while i>0:
if i%2==1:
result=result*base%mod
i=i>>1
base=base*base % MOD
return ret
def precompute_factorials(n):
"""0!부터 n!까지 전처리"""
fact = [1]*(n+1)
for i in range(1,n+1):
fact[i]=fact[i-1]*i%MOD
return fact
def combination_mod(n, r, fact):
"""C(n, r) mod mod (팩토리얼 전처리 사용)
C(n, r) = n! / (r! × (n-r)!)
≡ n! × (r!)^(-1) × ((n-r)!)^(-1) (mod mod)
"""
if r<0 or r>n:
return 0
if r==0 or r==n:
return 1
# n!
numerator = fact[n]
# r! × (n-r)!의 역원
# (r! × (n-r)!)^(-1) = (r!)^(-1) × ((n-r)!)^(-1)
d = fact[r] * fact[n-r] % MOD
inv_d = power_mod(d, MOD-2)
return n * inv_d % MOD
n, r = map(int, input().split())
fact = precompute_factorials(n)
print(combination_mod(n, r, fact))

참고