위 사진을 보면 배열이 있고 검은색 선(세 개의 쿼리)이 정렬되어있는 사진이다. 빨간색으로 색칠된 부분은 이미 이전에 계산해놓은 값을 가져다가 쓰기 때문에 따로 계산해 줄 필요가 없으므로 이러한 이점을 이용해서 푸는 알고리즘이다.
시간 복잡도는 O((Q + N)√N)이다.
구간합과 비슷한 문제를 예제로 구현해보자
문제 링크 : 8462_배열의 힘
자연수로 이루어진 배열이 주어지고 부분 배열의 힘이라는 것을 구한다. 이때 힘은 부분 배열에 있는 자연수 * 해당 자연수의 빈도 수 ** 2를 존재하는 자연수 모두 더해주면 된다.
이 문제는 update 연산이 없으므로 쿼리 순서를 마음대로 바꿀 수 있고 사실 구하는 방법만 다르지 구간합과 다를 것 없다.
쿼리를 받고 이 쿼리가 몇 번 째 답인지 알아야 하므로 index만 뒤에 붙여서 저장하고 왼쪽 기준으로 sqrt한 값과 오른쪽 값으로 오름차순으로 정렬한다.
cnt = [0] * 1000001 # 각 숫자 빈도수
ans = [0] * T # 정답 담을 배열
res = 0 # 현재 힘
for i in range(T):
s, e = map(int, input().split())
query.append((s, e, i))
query.sort(key=lambda x: (int(x[0] ** .5), x[1]))
아까 투 포인터와 닮았다고 말했고 현재 쿼리 구간을 다음 쿼리 구간으로 왼쪽, 오른쪽 포인터 지점을 변경시켜주는 것이 필요하다. 이때 움직일 때마다 힘을 바로 계산해주면 된다. 이렇게 해주지 않으면 나중에 힘을 계산할 때 존재할 수 있는 자연수의 배열을 linear search 해야 하기 때문이다.
def insert(val):
global res
res -= cnt[val] ** 2 * val
cnt[val] += 1
res += cnt[val] ** 2 * val
def delete(val):
global res
res -= cnt[val] ** 2 * val
cnt[val] -= 1
res += cnt[val] ** 2 * val
def move(ss, se, es, ee):
for i in range(ss, es):
delete(arr[i])
for i in range(se, ee, -1):
delete(arr[i])
for i in range(ss - 1, es - 1, -1):
insert(arr[i])
for i in range(se + 1, ee + 1):
insert(arr[i])
input = __import__('sys').stdin.readline
def insert(val):
global res
res -= cnt[val] ** 2 * val
cnt[val] += 1
res += cnt[val] ** 2 * val
def delete(val):
global res
res -= cnt[val] ** 2 * val
cnt[val] -= 1
res += cnt[val] ** 2 * val
def move(ss, se, es, ee):
for i in range(ss, es):
delete(arr[i])
for i in range(se, ee, -1):
delete(arr[i])
for i in range(ss - 1, es - 1, -1):
insert(arr[i])
for i in range(se + 1, ee + 1):
insert(arr[i])
if __name__ == '__main__':
N, T = map(int, input().split())
arr = [0] + list(map(int, input().split()))
query = []
cnt = [0] * 1000001
ans = [0] * T
res = 0
for i in range(T):
s, e = map(int, input().split())
query.append((s, e, i))
query.sort(key=lambda x: (int(x[0] ** .5), x[1]))
l = r = query[0][0]
r -= 1
for i in range(T):
s, e, idx = query[i]
move(l, r, s, e)
l, r = s, e
ans[idx] = res
for i in ans:
print(i)