Computer Science/Problem Solving

[BOJ 17646] 제곱수의 합 2 (More Huge), 무서운 루비를 풀어보자.

리유나 2022. 12. 30. 04:16

문제 링크는 이쪽입니다.

 

0. 이 문제는 도대체 뭔가?

제가 BOJ에서 풀었던, 아니 PS를 하면서 풀었던 모든 문제를 통틀어서 가장 시행착오를 많이 겪었던 문제가 아마 바로 이 문제가 아닐까 싶습니다. 생각보다 유명한 문제여서 solved.ac 기준 루비 4임에도 불구하고 제법 많은 사람들이 풀었지만, 그럼에도 결코 쉬운 문제는 아닙니다. 개인적으로 적어도 난이도 값은 무조건 하는 문제라고 생각합니다.

 

문제의 내용 자체는 사실 아주 간단한데, 어떤 자연수 n에 대해서 n을 제곱수들의 합으로 나타내라는 것입니다. 사용된 제곱수의 개수는 최소여야 하고요. 가령, n이 25라면 3^2+4^2로도 나타낼 수 있겠지만 5^2 하나로 나타내는 것이 더욱 적합하겠죠.

 

그런데 n의 범위가 10^18까지의 자연수네요. 결국 이게 제일 이 문제에서 곤란한 부분입니다. 이 문제를 풀기 위해서 필요했던 사전 지식들과 접근한 방법을 설명하겠습니다.

 

1. 우선 제곱수 몇개의 합이면 될까?

최대 4개라고 했는데, 4개보다 적게 필요한데 4개라고 무작정 출력할 수는 없는 노릇이죠. n에 대해 최소 개수가 얼마인지를 먼저 판별해봅시다.

 

1개일 때: 자명하게도, n이 제곱수일 때이다. 제곱수인지 판별하는 건 폴라드 로를 써서 인수분해를 하든, 나이브하게 대강 0.5제곱 해보고 근처 언저리 숫자들로 테스트해보든 하면 금방 가능합니다.

 

2개일 때: 기본적으로 소수는 4k+1 꼴이면 언제나 제곱수 두 개의 합으로 나타내질 수 있습니다. 4k+3은 mod 4로 테스트 해보면 안되는 게 금방 나오고, 다른 경우는 2 하나 뿐인데 2 또한 1+1이니까 제곱수 두개의 합으로 나타내어집니다. 이 사실은 페르마가 증명했습니다.

여기서 조금 더 응용해서, 제곱수 두개의 합으로 나타낼 수 있는 숫자 두개를 곱하면 언제나 제곱수 두개의 합으로 나타낼 수 있습니다. x=a^2+b^2, y=c^2+d^2라고 한다면 xy=(ac-bd)^2+(bc+ad)^2로 나타낼 수 있기 때문입니다. 따라서, n을 소인수분해 했을 때 4k+3 꼴의 소수들은 전부 짝수개가 나온다면 그것들은 제곱수로 따로 빼주고, 나머지 인수들을 모두 두 제곱수의 합으로 나타내고 위의 처리를 거치면 됩니다.

 

3개일 때: 르장드르의 세 제곱수 정리에 따르면 4^n(8k+7)꼴의 자연수가 아닌 모든 수는 제곱수 세개의 합으로 나타낼 수 있습니다.이 경우에서의 처리는 여러 가지 방법이 있는데, 저는 조금 무식하게 아무 k를 잡는다 -> n^2-k^2가 두개로 나타내어지는지 테스트한다->될때까지 한다 라는 방법을 사용했습니다. 어떤 수가 제곱수 두개의 합으로 나타내어질 수 있는 가능성이 생각보다 높다는 힌트를 어디서 들은 것을 이용했네요.

 

이 과정에서 우선 n을 인수분해하고, 가능한 한 제곱수 인수들을 미리 나눠두고, 나중에 다시 곱해주는 식으로 푸는 것이 좋습니다. 제가 위에서 TLE가 엄청나게 많이 나왔는데, 이 경우에서 이 부분 처리가 미흡했던 것을 고쳐서 최종적으로 AC를 받았습니다. (아주 큰 소수 p에 대해서 3p^2를 넣었을 때 시간 초과가 나는 이슈였습니다.)

 

4개일 때: 4^n(8k+7)꼴일 때인데, 4^n은 제곱수니까 먼저 빼주고, 남은 4^n(8k+6)은 3개의 합으로 나타내질테니 위 과정을 거치면 됩니다.

2. 그래서 4k+1꼴의 소수를 제곱수 두개의 합으로 어떻게 나타내는데?

여기까지 왔다면 결국 가장 문제는 어떤 4k+1꼴의 소수를 두 제곱수의 합으로 나타내는 방법입니다. 여기서는 Cornacchia 알고리즘을 사용하면 됩니다. 추가적으로, 풀이 과정에서 x^2=-1(mod p)라는 모듈러 이차 방정식을 풀어야 하는 일이 생겨서, 이를 빠르게 풀 수 있는 Tonelli-Shanks 알고리즘을 사용하면 됩니다.

 

사실 이렇게만 적으면 루비 4치곤 쉬운 게 아닌가? 할 수도 있지만

- 학부 정수론 수준 이상의 기초 지식을 교양으로 가지고 있어야 하고

- 밀러-라빈 소수 판정법과 폴라드-로 알고리즘에 대한 기본적인 이해가 필요하고

- 기본적으로 PS 정수론에서 잘 쓰이지 않는 알고리즘 2개를 구현해야 하고

- 마지막으로, 결국 PS기 때문에 시간을 줄이기 위해서 열심히 최적화하고 예외 처리들을 해주고 구현해야 하는 과정들이

하나하나 결코 쉽지 않은 문제였다고 생각합니다.

 

그래도 이 문제를 풀면서 배운 수학쪽 지식이 상당히 있었기 때문에, 저는 좋은 방향으로 기억에 남는 문제였네요!

3. 코드

주의: 코드가 무지 깁니다.

더보기
import sys
from random import randrange
from math import gcd
input=sys.stdin.readline

def millerrabin(n,a):
    r=0
    d=n-1
    while not d%2:
        r+=1
        d//=2
    x=pow(a,d,n)
    if x==1 or x==n-1:return True
    for i in ' '*(r-1):
        x=pow(x,2,n)
        if x==n-1:return True
    return False

prime_mem = dict()
def isPrime(n):
    if n in prime_mem:return prime_mem[n]
    pList=[2, 7, 61] if n < 4759123141 else [2, 325, 9375, 28178, 450775, 9780504, 1795265022]
    if n==1:return False
    if n<4:return True
    if not n%2:return False
    res= n in pList or all(millerrabin(n, p) for p in pList)
    prime_mem[n]=res
    return res

def pollard(n):
    if n==1:return 1
    if n%2==0:return 2
    if isPrime(n):return n
    x=randrange(1,n)
    c=randrange(1,n)
    g,y=1,x
    while g==1:
        x=(x**2+c)%n
        y=(y**2+c)%n
        y=(y**2+c)%n
        g=gcd(abs(x-y),n)
    if g==n:return pollard(n)
    return g

facto_mem = dict()

def facto(n):
    if n in facto_mem:return facto_mem[n]
    if n==1:
        facto_mem[n]=[]
        return []
    if n%2==0:
        facto_mem[n]=[2]+facto(n//2)
        return facto_mem[n]
    if isPrime(n):
        facto_mem[n]=[n]
        return [n]
    k=pollard(n)
    facto_mem[n] = facto(k)+facto(n//k)
    return facto_mem[n]

def sqrt(n):
    L=facto(n)
    d=dict()
    for i in L:
        if i not in d:d[i]=1
        else:d[i]+=1
    res=1
    for i in d:
        res*=pow(i,d[i]//2)
    return res

def howmany(n):
    f_list=facto(n)
    facto_dict=dict()
    for i in f_list:
        if i in facto_dict:facto_dict[i]+=1
        else:facto_dict[i]=1
    state=True
    state_2=True
    for d in facto_dict:
        if facto_dict[d]%2:state=False
        if d%4==3 and facto_dict[d]%2:state_2=False
    
    if state:return 1
    if state_2:return 2

    while n%4==0:n//=4
    if n%8!=7:return 3
    return 4

def tonelli(p):
    #assume that solve x^2=-1(mod p)
    q=p-1
    s=0
    while q%2==0:
        q//=2
        s+=1
    z=randrange(2,p)
    while pow(z, (p-1)//2, p)==1:
        z=randrange(2,p)
    m=s
    c=pow(z,q,p)
    t=pow(p-1,q,p)
    r=pow(p-1,(q+1)//2,p)

    if t==0:return 0
    while t!=1 and t!=0:
        tt=t
        i=0
        while t%p!=1:
            t=pow(t,2,p)
            i+=1
        b=pow(c, pow(2, m-i-1, p), p)
        m=i
        c=pow(b,2,p)
        t=(tt*c)%p
        r*=b
        r%=p
    return r

def cornacchia(p):
    #to solve x^2+y^2=p
    if p%4==3:return False
    if p==2:return 1
    r=tonelli(p)
    rr=p
    while r**2>p:
        rr%=r
        if rr**2<p:return rr
        r%=rr
    return r

def get_1(n):
    return [sqrt(n)]

def get_2(n):
    mem=facto(n)
    facto_dict=dict()
    for i in mem:
        if i in facto_dict:facto_dict[i]+=1
        else:facto_dict[i]=1
    mem=facto_dict
    mul=1
    L=[]
    for d in mem:
        mul*=pow(d,mem[d]//2)
        if mem[d]%2:L.append(d)
    ans=[1,0]
    for dd in L:
        k=cornacchia(dd)
        res=[k, sqrt(dd-k**2)]
        a,b=ans
        c,d=res
        ans=[a*d+b*c, abs(a*c-b*d)]
    
    ans[0]*=mul
    ans[1]*=mul
    return ans

def get_3(n):
    f_list=facto(n)
    facto_dict=dict()
    for i in f_list:
        if i in facto_dict:facto_dict[i]+=1
        else:facto_dict[i]=1
    mul=1
    new_n=1
    for d in facto_dict:
        mul*=pow(d, facto_dict[d]//2)
        new_n*=pow(d, facto_dict[d]%2)
    t=1
    while howmany(new_n-t**2)!=2:
        t+=1
    res = get_2(new_n-t**2)+[t]
    res[0]*=mul
    res[1]*=mul
    res[2]*=mul
    return res

def get_4(n):
    ct=0
    while n%4==0:
        ct+=1
        n//=4
    res=get_3(n-1)
    pow_ct=pow(2,ct)
    res[0]*=pow_ct
    res[1]*=pow_ct
    res[2]*=pow_ct
    return res+[pow_ct]

n=int(input())
k=howmany(n)
print(k)
if k==1:
    for i in get_1(n):print(i,end=' ')
if k==2:
    for i in get_2(n):print(i,end=' ')
if k==3:
    for i in get_3(n):print(i,end=' ')
if k==4:
    for i in get_4(n):print(i,end=' ')
print()

'Computer Science > Problem Solving' 카테고리의 다른 글

2023 ICPC Korea regional 인터넷 예선 후기  (4) 2023.10.23
[BOJ 27533] 따로 걸어가기  (0) 2023.04.24
[BOJ 13522] 악마의 수열  (2) 2022.12.23
[BOJ 2022] 사다리  (0) 2022.12.12
solved.ac CLASS 8 달성  (0) 2022.12.06