코딩 테스트 스터디/백준

[실버 I] 14888번. 연산자 끼워넣기

남쪽마을밤송이 2022. 8. 8. 21:57

 문제 

https://www.acmicpc.net/problem/14888

 

14888번: 연산자 끼워넣기

첫째 줄에 수의 개수 N(2 ≤ N ≤ 11)가 주어진다. 둘째 줄에는 A1, A2, ..., AN이 주어진다. (1 ≤ Ai ≤ 100) 셋째 줄에는 합이 N-1인 4개의 정수가 주어지는데, 차례대로 덧셈(+)의 개수, 뺄셈(-)의 개수, 

www.acmicpc.net

 사용 언어 

Python3

 풀이 과정 

이 문제는 저번에 프로그래머스 못 풀었던 문제처럼(좀 더 쉽지만) 완전 탐색 방법이 제일 처음 떠올랐다.

그래서 그대로 풀어봤더니 python3에서는 시간초고ㅏ...

혹시 몰라 pypy3로 돌려보니까 진짜 천천히 올라가더니 맞았습니다!!가 나오긴 했다.

이후 이중 for문으로 list 선언한 부분을 list comprehension으로 변경했더니 8ms가 줄어들었다.

# 변경 전
op = []
for i in range(len(op_num)):
    for j in range(op_num[i]):
        op.append(op_list[i])
        
# 변경 후
op = [op_list[i] for i in range(len(op_num)) for _ in range(op_num[i])]

 

찾아보니 백트래킹 방법을 쓰면 python3로도 통과할 수 있는 것 같아 그 방법은 아래에 정리해보겠다.

백트래킹... 요즘 여러번 들은 단어인데 아직 제대로 공부해본 적 없는 친구였는데 permutation(순열)을 이용한 brute force 방법과 비교했을 때 얼마나 차이가 나는지 보기 위해 pypy3와 python3로 각각 돌려본 결과이다.

960ms가 216ms로 엄청 많이 줄어든 것을 확인했다.

python3으로도 pypy3보다 훨씬 좋은 메모리와 시간복잡도를 보이며 깔끔하게 통과했다.

 

 제출 답안 

from sys import stdin
from itertools import permutations
input = stdin.readline

N = int(input())
num = list(map(int, input().split()))
op_num = list(map(int, input().split()))  # +, -, *, /
op_list = ['+', '-', '*', '/']
op = []

for i in range(len(op_num)):
    for j in range(op_num[i]):
        op.append(op_list[i])

maximum = -1e9
minimum = 1e9

def bruteforce():
    global maximum, minimum
    for case in permutations(op, N - 1):
        sum = num[0]
        for k in range(1, N):
            if case[k - 1] == '+':
                sum += num[k]
            elif case[k - 1] == '-':
                sum -= num[k]
            elif case[k - 1] == '*':
                sum *= num[k]
            elif case[k - 1] == '/':
                sum = int(sum / num[k])

        if sum > maximum:
            maximum = sum
        if sum < minimum:
            minimum = sum
            
if __name__ == "__main__":
    bruteforce()
    print(maximum)
    print(minimum)

 

 공부한 내용 

백트래킹

백트래킹은 모든 경우의 수를 전부 고려하는 알고리즘이다.

일종의 탐색 알고리즘인데 DFS와 BFS 두가지로 구현이 가능하지만 일반적으로 DFS로 구현하는게 더 편하다고 한다. 그 이유는 BFS 방식의 경우 큐의 크기가 커질 수 있기 때문이다.

다만 DFS를 쓸 수 없는 경우가 있는데 트리의 깊이가 무한대가 될 때이다. 이 경우에는 BFS를 써야 한다.

 

하지만 일반적인 경우에는 DFS로 모두 통용되니 기본적으로 백트래킹은 DFS로 구현한다고 생각하면 좋다고 한다.

 

백트래킹은 답이 될 수 없는 후보는 더이상 깊게 들어가지않고 되돌아가는 방법을 의미한다.

그러므로 백트래킹은 모든 경우의 수를 탐색하는 브루트포스(brute force) 방법보다 훨씬 더 시간을 절약할 수 있게 된다.

자세한 예시는 출처를 참고한다.

 

정확한 개념 이해는 아래 영상으로 이해했는데 30분짜리이지만 확실하게 백트래킹이 뭔지 이해할 수 있었다.

 

 

출처: https://kyun2da.github.io/2020/08/10/backTracking/

 

 

 다른 풀이 

이 문제에 백트래킹 방법을 적용한 코드들 중 내가 가장 마음에 들었던 코드를 내 스타일대로 변경하며 이해했다.

# python3로 통과
# 백트래킹 사용
from sys import stdin
input = stdin.readline

def backtracking(sum, idx, add, sub, mul, div):
    global maximum, minimum
    if idx == N:
        maximum = max(maximum, sum)
        minimum = min(minimum, sum)
        return

    if add > 0:
        backtracking(sum + nums[idx], idx + 1, add - 1, sub, mul, div)
    if sub > 0:
        backtracking(sum - nums[idx], idx + 1, add, sub - 1, mul, div)
    if mul > 0:
        backtracking(sum * nums[idx], idx + 1, add, sub, mul - 1, div)
    if div > 0:
        backtracking(int(sum / nums[idx]), idx + 1, add, sub, mul, div - 1)

if __name__ == "__main__":
    N = int(input())
    nums = list(map(int, input().split()))
    add, sub, mul, div = map(int, input().split())
    maximum = -1e9
    minimum = 1e9

    backtracking(nums[0], 1, add, sub, mul, div)
    print(maximum)
    print(minimum)

출처 : https://zidarn87.tistory.com/350

 

 다른 사람 풀이 

1. 

# 백트래킹 (Python3 통과, PyPy3도 통과)
import sys

input = sys.stdin.readline
N = int(input())
num = list(map(int, input().split()))
op = list(map(int, input().split()))  # +, -, *, //

maximum = -1e9
minimum = 1e9

def dfs(depth, total, plus, minus, multiply, divide):
    global maximum, minimum
    if depth == N:
        maximum = max(total, maximum)
        minimum = min(total, minimum)
        return

    if plus:
        dfs(depth + 1, total + num[depth], plus - 1, minus, multiply, divide)
    if minus:
        dfs(depth + 1, total - num[depth], plus, minus - 1, multiply, divide)
    if multiply:
        dfs(depth + 1, total * num[depth], plus, minus, multiply - 1, divide)
    if divide:
        dfs(depth + 1, int(total / num[depth]), plus, minus, multiply, divide - 1)


dfs(1, num[0], op[0], op[1], op[2], op[3])
print(maximum)
print(minimum)

출처 : https://velog.io/@kimdukbae/BOJ-14888-%EC%97%B0%EC%82%B0%EC%9E%90-%EB%81%BC%EC%9B%8C%EB%84%A3%EA%B8%B0-Python 

 

2.

def dfs(index,res):
    global minAns
    global maxAns
    # 계산의 끝에 도달했을 때 최댓값과 최솟값이 될 수 있는지 판단한다.
    if index==N-1:
        if minAns > res:
            minAns = res
        if maxAns < res:
            maxAns = res
        return res
    # 백트래킹 DFS로 순회
    for i in range(4):
        temp = res
        if operator[i]==0:
            continue
        if i==0:
            res+=numArr[index+1]
        elif i==1:
            res-=numArr[index+1]
        elif i==2:
            res*=numArr[index+1]
        else:
            if res<0:
                res = abs(res)//numArr[index+1]*-1
            else:
                res //=numArr[index+1]
        operator[i] -= 1
        dfs(index+1,res)
        operator[i] += 1
        res = temp

N = int(input())
numArr = list(map(int,input().split()))
operator = list(map(int,input().split()))
minAns = float('Inf')
maxAns = float('-Inf')

dfs(0,numArr[0])
print(maxAns)
print(minAns)

출처 : https://kyun2da.github.io/2020/08/11/putOperator/