Computer Science/Problem Solving

[BOJ 1081] 합, Digit DP가 뭘까?

리유나 2023. 12. 1. 01:05

안녕하세요. 리유나입니다.

 

생각해보니 ICPC에서 J번을 그럭저럭 괜찮은 페이스로 자력 해결을 했는데, 그보다 훨씬 쉬우면서 비슷한 로직을 가진 1081번 합.이라는 문제가 있었고 제가 이걸 6년 전, PS에 거의 처음 입문했을 때 시도했다가 틀린뒤로 건드리지도 않았다는 사실을 깨달았습니다.

https://www.acmicpc.net/problem/1081

 

1081번: 합

L보다 크거나 같고, U보다 작거나 같은 모든 정수의 각 자리의 합을 구하는 프로그램을 작성하시오.

www.acmicpc.net

한번의 구현 실수를 제외하면, 아주 손쉽게 풀리는 스타일이었지만, 처음 입문하시는 분들이라면 다들 고등학교 시절의 저처럼 무언가 실수를 하실 부분이 있을듯 하니 도움이 되었으면 하는 바람에 포스팅을 시작합니다.

 

1. 나이브하게 구현하면 안될까?

안됩니다. 이게 됐으면 골드 1이 아니죠.

 

고등학교 3학년 때 제출했던 Python 2 코드가 이런 식으로 적혀 있는데, 이건 뭐 파이썬이 아니라 사실 무슨 언어를 쓰든 TLE가 날 수밖에 없는 코드입니다. 아니 그보다 저때는 무려 '런타임 에러'가 났더라고요? 대체 뭘 잘못했던 걸까요...

 

당신이 프로그래밍을 이제 막 처음 배우면서 for문 while문을 공부하고 있고, n이 한 100만 정도로 주어진다면 저런 코드는 아주 괜찮은 풀이 방법이 될 수 있습니다...만, 그렇다면 굳이 제가 이 글을 작성하고 있지도 않았을 것입니다.

 

시간 복잡도는 O(n)으로, 사실 웬만한 문제에서는 이정도면 정말로 준수한 시간 복잡도이지만, 이 문제는 20억까지의 숫자가 주어질 수 있다는 점을 체크합시다.

 

2. 어떻게 생각해야 할까?

일단 기본적으로 이런 류의 문제는 f(n): 1에서 n까지의 수들의 자릿수 합을 모두 더한 것. 이라고 생각하고 f(b)-f(a-1)을 구하는 방향이 대체로 현명합니다. 특히나 비슷한 문제인 5425 자리합은 input이 쿼리로 주어져서 더더욱 그렇고요. 그렇다면 f(n)을 구하는 문제로 자연스레 바꿔서 생각해볼 수 있겠네요. f(n)을 구하는 가장 현명한 방법은 무엇일까요?

 

물론 이 문제에는 여러 풀이가 있습니다. 초등 경시 문제 스타일로 자릿수에 1이 총 몇개 들어가는지, 2가 총 몇개 들어가는지, ..., 9가 총 몇개 들어가는지 게산해서 세는 방법도 있고, 이 방법 또한 나쁘지 않은 풀이방법입니다. 하지만 조금 다른 방식으로 머리를 써볼까요?

 

예를 들어서 0에서 4940까지의 자릿수 합을 전부 구해본다고 가정합시다. 0-999까지의 구간과 1000-1999까지의 구간은, 앞에 1이 붙는다는 점을 제외한다면 완전히 같은 내용을 계산합니다. 이건 2000-2999까지의 구간, 3000-3999까지의 구간도 마찬가지입니다. 즉, 0에서 999까지의 구간을 미리 구해놓은 값이 있다면, 0에서 3999까지의 구간 안에 있는 모든 자연수의 자릿수 합은 쉽게 구할 수 있습니다. 4000-4940은 따로 0에서 940까지 구해놓은 값이 있으면 될거고, 이는 재귀로 넘기거나 직접 계산하거나 하면 되겠네요.

 

짜잔! 문제의 자릿수가 하나 줄었습니다. 이런 식으로 자릿수를 계속 줄여나갈 수 있고, 이를 통해서 자릿수 정도의 시간 복잡도만 투자해도 문제를 금방 풀 수 있습니다. 이렇게 각 자릿수의 숫자와 관련이 있는 문제가 나왔을 때, 자릿수 단위로 생각하면서 문제를 해결하는 동적 계획법 기법을 Digit-dp(자릿수 dp)라고 부른다고 하더라고요. 이 방법을 사용한다면 O(log N)정도의 시간 복잡도가 나오겠네요. 이정도라면 20억을 넣어도 뭐 거의 상수입니다. 시간 걱정 안하고 풀 수 있습니다.

 

3. 풀이

코드 복사-붙여넣기 하는 당신이 아름답습니다.

 

d=dict()
d[0]=0
def solve(n):
	if n<0:return 0
	if n in d:return d[n]
	if n<10:
		res=0
		for i in range(n+1):res+=i
		d[n]=res
		return res
	first=n
	ct=0
	res=0
	while first>9:
		first//=10
		ct+=1
	k=solve(10**ct-1)
	for i in range(first):
		res+=i*10**ct+k
	other=n-first*10**ct
	res+=(solve(other)+(other+1)*first)
	d[n]=res
	return res
a,b=map(int,input().split())
print(solve(b)-solve(a-1))