Minerva: 언어 모델을 활용한 정량적 추론 문제 해결

Minerva는 수학, 과학, 공학 분야의 정량적 추론 문제 해결에 특화된 대규모 언어 모델입니다. 일반 자연어 데이터로 사전 학습된 PaLM 모델을 기반으로, arXiv 논문 및 수학 콘텐츠가 포함된 웹페이지로 구성된 기술 데이터셋으로 추가 학습되었습니다. 이 모델은 외부 도구 없이도 단계별 풀이 과정을 생성하며, MATH, GSM8k와 같은 기술 벤치마크에서 최고 수준의 성능을 달성했습니다. Minerva는 자연어 설명과 정확한 LaTeX 수식을 혼합하여 복잡한 문제의 해답을 제시하는 능력을 보여줍니다. 논문 제목: Solving Quantitative Reasoning Problems with Language Models

Lewkowycz, Aitor, et al. "Solving quantitative reasoning problems with language models." Advances in neural information processing systems 35 (2022): 3843-3857.

Solving Quantitative Reasoning Problems with Language Models

Abstract

Language model들은 **자연어 이해(natural language understanding)**를 요구하는 광범위한 task에서 놀라운 성능을 달성해왔다. 그럼에도 불구하고, 대학 수준의 수학, 과학, 공학 문제 해결과 같이 **정량적 추론(quantitative reasoning)**을 요구하는 task에서는 state-of-the-art 모델들이 일반적으로 어려움을 겪어왔다. 이러한 격차를 줄이기 위해 우리는 Minerva를 소개한다. Minerva는 일반 자연어 데이터로 사전학습된 후, 기술 콘텐츠로 추가 학습된 대규모 language model이다. 이 모델은 외부 도구의 도움 없이도 기술 벤치마크에서 state-of-the-art 성능을 달성한다. 또한 우리는 정량적 추론을 요구하는 200개 이상의 학부 수준 물리, 생물학, 화학, 경제학 및 기타 과학 문제에 대해 모델을 평가했으며, 그 결과 모델이 거의 3분의 1에 해당하는 문제를 정확하게 해결할 수 있음을 확인했다.

1 Introduction

인공신경망은 컴퓨터 비전, 음성 인식, 오디오 및 이미지 생성, 번역, 게임 플레이, 로봇 공학 등 다양한 분야에서 놀라운 성공을 거두었다. 특히, **대규모 언어 모델(large language models)**은 상식 추론, 질문 응답, 요약 등 다양한 자연어 task에서 뛰어난 성능을 달성했다 (Raffel et al., 2019; Brown et al., 2020; Rae et al., 2021; Smith et al., 2022; Chowdhery et al., 2022). 그러나 이러한 모델들은 수학, 과학, 공학 문제 해결과 같이 정량적 추론(quantitative reasoning)을 요구하는 task에서는 어려움을 겪어왔다 (Hendrycks et al., 2021; Cobbe et al., 2021).

정량적 추론 문제는 언어 모델의 여러 능력을 시험한다는 점에서 흥미로운 응용 분야이다. 이러한 문제들은 해결자가 자연어 입력을 정확하게 파싱하고, 문제와 관련된 세계 지식을 잠재적으로 상기하며, 제공된 정보에 알고리즘 또는 일련의 계산을 적용하여 올바른 해답에 도달하도록 요구한다. 또한, 해결자가 정확한 수학적 token 시퀀스를 파싱하고 생성할 수 있어야 하며, 기호 및 수치 조작을 통해 token에 계산 절차를 적용할 수 있어야 한다. 마지막으로, 이러한 문제들은 과학 및 기술 분야에서 인간의 작업을 지원하는 데 유용한 강력한 정량적 추론 해결책을 향한 연구의 시험대가 된다.

이전 연구에서는 대규모 언어 모델이 domain-specific 데이터셋으로 학습한 후 수학 및 프로그래밍 질문에서 인상적인 성능을 달성했음을 보여주었다 (Chen et al., 2021; Austin et al., 2021;

Question: A line parallel to y=4x+6y=4 x+6 passes through (5,10)(5,10). What is the yy-coordinate of the point where this line crosses the yy-axis?

Target: Since the line is parallel to y=4x+6y=4 x+6, we know that its slope is 4 . Thus the equation of the line is y=4x+by=4 x+b where bb is the yy-intercept. Plugging (5,10) into this gives 10=45+bb=1010=4 \cdot 5+b \Rightarrow b=-10 which is what we wanted.

Model output: The line is parallel to y=4x+6y=4 x+6, which means that it has the same slope as y=4x+6y=4 x+6. Since the line passes through (5,10)(5,10), we can use the point-slope form of a line to find the equation of the line:

y10=4(x5)y10=4x20y=4x10\begin{aligned} y-10 & =4(x-5) \\ y-10 & =4 x-20 \\ y & =4 x-10 \end{aligned}

Now that we have the equation of the line, we can find the yy-coordinate of the point where the line crosses the yy-axis by substituting x=0x=0 into the equation:

y=4010=10y=4 \cdot 0-10=-10

Figure 1: MATH 데이터셋 (Algebra 4)의 질문과 Minerva 62B가 생성한 해답. 이 해답은 256개의 샘플 중 다수결 투표로 선정되었다. 이 질문과 다른 MATH 질문들은 동일한 4-shot prompt로 제시되었다. 모델은 선 방정식을 작성하고, 이를 단순화하며, xx를 대입하고, yy를 풀어 해답을 도출한다. 모델은 자연어 설명과 함께 유효한 ETEX\mathrm{ET}_{\mathrm{E}} \mathrm{X}를 출력한다. 모델이 취한 접근 방식은 ground truth target에서 사용된 방식과 다르다.

Drori et al., 2021). 본 연구에서는 이 접근 방식을 정량적 추론 문제에 적용하며, 이 문제에서는 모델이 외부 도구에 의존하지 않고 완전하고 자체 포함된 해답을 제공해야 한다. 이러한 task에는 수학 word problem, 경쟁 수학 평가, 그리고 과학 및 공학의 많은 문제들이 포함된다.

1.1 Our Contribution

우리는 다양한 정량적 추론(quantitative reasoning) task에서 강력한 성능을 달성하는 언어 모델인 Minerva를 소개한다. 우리의 모델은 자연어로 표현된 과학 및 수학 질문을 처리하고, 정확한 ETEX\mathrm{ET}_{\mathrm{E}} \mathrm{X} 표기법을 사용하여 단계별 솔루션을 생성할 수 있다. Figure 1, 2, 3은 Minerva가 수학 및 물리학 질문에 대해 생성한 응답 예시를 보여준다.

Minerva는 PaLM 일반 언어 모델 (Chowdhery et al., 2022)을 기반으로 하며, 과학 및 수학 데이터를 포함하는 고품질 데이터셋으로 추가 학습되었다. 특히, 우리는 8B, 62B, 540B 파라미터의 사전학습된 모델로 시작하여, 이를 우리의 기술 콘텐츠 데이터셋으로 계속 학습시켰다. 그 결과, MATH (Hendrycks et al., 2021), GSM8k (Cobbe et al., 2021), 그리고 MMLU (Hendrycks et al., 2020) 데이터셋의 STEM 하위 집합에서 state-of-the-art 성능을 달성했다. 이 모든 데이터셋은 자연어로 표현된 수학 및 과학 질문을 포함한다. 특히, 우리 모델은 이러한 평가 데이터셋에 명시적으로 학습되지 않았음에도 few-shot 설정에서 견고한 성능을 보인다.

본 논문의 주요 참신성은 자연어와 방정식 및 다이어그램과 같은 형식적인 수학 언어의 올바른 사용을 병치(juxtapose)한 대규모 학습 데이터셋이다. 이 데이터는 arXiv preprint 서버와 웹페이지에서 수집되었으며, 수학적 콘텐츠의 손실을 최소화하기 위해 신중하게 처리되었다. 이 연구는 새로운 state-of-the-art 결과를 확립하는 것 외에도, 데이터 품질과 모델 크기를 증가시킴으로써 정량적 추론 벤치마크에서 달성 가능한 성능에 대한 새로운 baseline을 확립한다.

정량적 추론 벤치마크의 범위를 더욱 넓히기 위해, 우리는 MIT의 OpenCourseWare (OCW)에서 200개 이상의 학부 수준 과학 및 수학 질문으로 구성된 데이터셋을 구축했다. 이는 순수한 수학적 환경을 넘어 chain-of-thought 맥락에서 우리 모델의 정량적 추론 능력을 측정하는 척도를 제공한다.

Question: Assume that the variance of the first nn natural numbers is 10 , and the variance of the first mm even natural numbers is 16 . Compute m+nm+n.

Model output: Let σ2\sigma^{2} be the variance of the first mm even natural numbers, and τ2\tau^{2} be the variance of the first nn natural numbers. Then, we have τ2=n2112\tau^{2}=\frac{n^{2}-1}{12} and σ2=m213\sigma^{2}=\frac{m^{2}-1}{3}. We solve the system of equations:

n2112=10m213=16\begin{aligned} & \frac{n^{2}-1}{12}=10 \\ & \frac{m^{2}-1}{3}=16 \end{aligned}

This translates into the system

n2=121m2=48+1=49\begin{aligned} n^{2} & =121 \\ m^{2} & =48+1=49 \end{aligned}

Therefore, n=11n=11 and m=7m=7, so n+m=n+m= 18.

Question: For every a,b,baa, b, b \neq a prove that

a2+b22>(a+b2)2\frac{a^{2}+b^{2}}{2}>\left(\frac{a+b}{2}\right)^{2}

Model output:

a2+b22>(a+b2)2a2+b22>a2+b2+2ab4a2+b2>a2+b2+2ab22a2+2b2>a2+b2+2aba2+b2>2aba2+b22ab>0(ab)2>0\begin{array}{r} \frac{a^{2}+b^{2}}{2}>\left(\frac{a+b}{2}\right)^{2} \\ \Longleftrightarrow \frac{a^{2}+b^{2}}{2}>\frac{a^{2}+b^{2}+2 a b}{4} \\ \Longleftrightarrow a^{2}+b^{2}>\frac{a^{2}+b^{2}+2 a b}{2} \\ \Longleftrightarrow 2 a^{2}+2 b^{2}>a^{2}+b^{2}+2 a b \\ \Longleftrightarrow a^{2}+b^{2}>2 a b \\ \Longleftrightarrow a^{2}+b^{2}-2 a b>0 \\ \Longleftrightarrow(a-b)^{2}>0 \end{array}

which is true, because the square of a real number is positive.

Figure 2: Minerva 540B의 솔루션이 포함된 질문 예시. 왼쪽: 매년 약 2백만 명의 인도 고등학생들이 공학 및 유사 분야를 공부하기 위해 응시하는 Joint Entrance Examination Main Math 2020 시험의 질문. 오른쪽: 매년 약 27만 명의 고등학생들이 응시하는 **폴란드 전국 수학 시험 (2022년 5월)**의 질문.

자연어로 표현된 정량적 추론 문제 해결은 활발한 연구 분야였다 (Koncel-Kedziorski et al., 2015; Hosseini et al., 2014). scratchpad (Nye et al., 2021) 또는 chain-of-thought (Wei et al., 2022) 솔루션을 사용하여 language model에 prompt를 제공하면, 모델이 이전에 보지 못한 문제에 대한 단계별 솔루션을 출력하도록 유도할 수 있다. GSM8k 연구 (Cobbe et al., 2021)는 모델 출력을 재순위화(rerank)하기 위해 verifier를 훈련하는 것이 성능 향상으로 이어질 수 있음을 보여주었다. GSM8k의 원래 버전에는 대수 계산을 위한 특수 구문이 포함되어 있었고, 이는 계산기에 의해 처리되었다. 본 연구에서는 외부 도구에 접근하지 않는 자율적인(self-contained) 모델에 초점을 맞춘다.

생성(generative) task에서 language model을 평가하는 표준 방법은 문제당 하나의 솔루션을 greedy하게 샘플링하는 것이다. 최근 연구들 (Chen et al., 2021; Li et al., 2022; Thoppilan et al., 2022; Wang et al., 2022)은 문제당 여러 솔루션을 샘플링한 다음, 이를 최종 답변으로 필터링하는 것이 유리하다는 것을 보여주었다. 우리는 다수결 투표(majority voting) (Wang et al., 2022)가 greedy decoding보다 성능을 크게 향상시킨다는 것을 발견했다.

Drori et al. (2021)의 연구에는 MATH 데이터셋의 90개 문제 하위 집합에 대한 OpenAI의 최신 공개 language model인 davinci-002의 평가가 포함되어 있다. 질문 하위 집합에 초점을 맞추고, 질문 형식 지정 방식에 변경이 있었기 때문에, 우리의 결과를 Drori et al. (2021)의 결과와 직접 비교하기는 어렵다. Section 3에서는 동일한 실험 조건에서 OpenAI davinci-002와 우리의 모델을 비교한다.

Question: 균일한 고체 원판이 정지 상태에서 미끄러지지 않고 경사면을 굴러 내려간다. 일정 시간 후, 원판의 전체 운동 에너지 중 회전 운동 에너지가 차지하는 비율은 얼마인가? 다음 중 올바른 선택은 무엇인가? 답을 설명하시오. (A) 1/41 / 4 (B) 1/31 / 3 (C) 1/21 / 2 (D) 2/32 / 3

Model output: 원판은 병진 운동 에너지와 회전 운동 에너지를 모두 가진다. 원판의 질량 중심 속도를 vv, 원판의 관성 모멘트를 II, 원판의 질량을 MM, 원판의 반지름을 RR이라고 하자. 원판의 병진 운동 에너지는 다음과 같다.

Kt=12Mv2K_{t}=\frac{1}{2} M v^{2}

원판의 회전 운동 에너지는 다음과 같다.

Kr=12Iω2=12(12MR2)(vR)2=14Mv2K_{r}=\frac{1}{2} I \omega^{2}=\frac{1}{2}\left(\frac{1}{2} M R^{2}\right)\left(\frac{v}{R}\right)^{2}=\frac{1}{4} M v^{2}

따라서 회전 운동 에너지와 전체 운동 에너지의 비율은 다음과 같다.

KrKt+Kr=14Mv212Mv2+14Mv2=14Mv234Mv2=13\frac{K_{r}}{K_{t}+K_{r}}=\frac{\frac{1}{4} M v^{2}}{\frac{1}{2} M v^{2}+\frac{1}{4} M v^{2}}=\frac{\frac{1}{4} M v^{2}}{\frac{3}{4} M v^{2}}=\frac{1}{3}

Final Answer: 최종 답은 (B)이다.

Question: 한 입자가 시간 tt(3sint4,3cost4)\left(3 \sin \frac{t}{4}, 3 \cos \frac{t}{4}\right)에 위치한다. 단위 시간당 거리 단위로 측정된 입자의 속도를 구하시오.

Target: 매개변수 곡선 (x,y)=(3sint4,3cost4)(x, y)=\left(3 \sin \frac{t}{4}, 3 \cos \frac{t}{4}\right)는 반지름 3인 원을 나타낸다. 또한, t=8πt=8 \pi일 때 한 바퀴를 완전히 돈다. (figure) 따라서 입자의 속도는 6π8π=\frac{6 \pi}{8 \pi}=| 34\frac{3}{4} | | :--- | | . | .

Model output: 속도는 속도 벡터의 크기이다. 따라서 먼저 속도 벡터를 찾아야 하며, 이는 위치 벡터의 미분이다. 우리는 다음을 얻는다.

v=ddt(3sint4,3cost4)=3(cost4,sint4)14=(34cost4,34sint4).\mathbf{v}=\frac{d}{d t}\left(3 \sin \frac{t}{4}, 3 \cos \frac{t}{4}\right)=3\left(\cos \frac{t}{4},-\sin \frac{t}{4}\right) \frac{1}{4}=\left(\frac{3}{4} \cos \frac{t}{4},-\frac{3}{4} \sin \frac{t}{4}\right) .

따라서 속도 벡터의 크기는 다음과 같다.

v=(34cost4)2+(34sint4)2=916cos2t4+916sin2t4=34cos2t4+sin2t4=34\begin{aligned} |\mathbf{v}| & =\sqrt{\left(\frac{3}{4} \cos \frac{t}{4}\right)^{2}+\left(-\frac{3}{4} \sin \frac{t}{4}\right)^{2}}=\sqrt{\frac{9}{16} \cos ^{2} \frac{t}{4}+\frac{9}{16} \sin ^{2} \frac{t}{4}} \\ & =\frac{3}{4} \sqrt{\cos ^{2} \frac{t}{4}+\sin ^{2} \frac{t}{4}}=\frac{3}{4} \end{aligned}

따라서 입자의 속도는 | 34\frac{3}{4} | | :--- | .

Figure 3: Minerva 솔루션이 포함된 물리 문제. 상단: MMLU (College Physics)의 예시 질문과 Minerva 540B가 생성한 솔루션. 모델은 움직이는 물체의 물리적 특성을 설명하기 위한 표기법을 도입하고, 관련 물리적 관계를 작성하며, 고체 원판의 관성 모멘트를 상기하고, 원하는 에너지 비율을 올바르게 계산한다. 하단: MATH 데이터셋 (Precalculus 4)의 질문. 모델은 위치 벡터의 미분을 올바르게 취하여 속도를 계산한다. 그런 다음 결과 벡터의 norm을 계산하고, 삼각 함수 항등식을 사용하여 최종 수치 답을 얻는다. Minerva는 ground truth 솔루션과는 완전히 다른 접근 방식을 취한다.

코드 생성(Code generation). 코드 생성 모델을 수학 문제에 적용하는 것은 활발한 탐구 영역이었다. PaLM (Chowdhery et al., 2022)은 훈련 데이터셋에 코드를 포함하는 대규모 language model이 GSM8k의 코드 버전에서 좋은 성능을 달성할 수 있음을 보여주었다. 또한, Codex 모델 (Chen et al., 2021)은 MATH 문제에 대한 코드 솔루션을 생성할 수 있다 (Drori et al., 2021). 이러한 솔루션은 종종 방정식 풀이나 극한 계산과 같은 수학적 연산을 수행하기 위해 외부 라이브러리에 의존한다. 이는 모델 자체의 추론 능력에만 의존하여 답에 도달하는 모델의 능력을 직접적으로 탐색하는 우리의 접근 방식과는 상호 보완적인 접근 방식이다.

형식 수학(Formal mathematics). 수학은 자연어에 기반한 학문으로 발전했지만, 그 공리적 기초는 수학적 사고를 시뮬레이션하는 것을 가능하게 한다. 이는 Coq (development team, 2022), Isabelle (Wenzel et al., 2008), HOL4 (Harrison, 1996), Lean (de Moura et al., 2015), Metamath (Megill and Wheeler, 2019), Mizar (Grabowski et al., 2010)과 같이 컴퓨터를 사용하여 논리적 및 수학적 사고의 시뮬레이션을 용이하게 하는 특수 프로그래밍 언어를 사용하여 달성할 수 있다. E (Schulz, 2013), leanCoP (Otten, 2008), Vampire (Kovács and Voronkov, 2013)와 같은 proof assistant 및 automated theorem prover의 자동화 연구는 머신러닝 방법과의 통합을 통해 상당한 이점을 얻었다 (Alemi et al., 2016; Goertzel et al., 2021; Li et al., 2021; Polu and Sutskever, 2020; Kaliszyk et al., 2018).

형식적 및 합성 수학 문제에 적용된 language model. 이전 연구에서는 수학적 표현을 예측하도록 language model을 훈련했다 (Rabe et al., 2021; Li et al., 2021; Polu and Sutskever, 2020; Wu et al., 2021; Han et al., 2022; Polu et al., 2022; Jiang et al., 2022; Wu et al., 2022). 이러한 예측 모델은 Polu and Sutskever (2020)에서 수행된 것처럼 proof search를 안내하는 데 사용될 수 있다. 대규모 language model은 자연어 모델링에 탁월하지만, 형식 언어의 경우 GNN과 같이 주어진 수학 공식의 그래프 구조에 대한 정보를 유지하는 데 도움이 되는 모델이 여전히 매우 경쟁력이 있다.

자연어 학문으로서의 수학 모델링. 새로운 벤치마크 데이터셋 (Hendrycks et al., 2021; Welleck et al., 2021)은 더 고급 수학 주제를 다룬다. 이 영역에서 language model은 다른 종류의 모델들과 제한된 경쟁에 직면해 있다.

2 Training and Evaluation

2.1 Mathematical Training Dataset

우리의 모델은 수학적 콘텐츠를 위해 필터링된 웹페이지arXiv preprint 서버에 제출된 논문에서 수집된 385억 개의 토큰으로 구성된 데이터셋으로 학습되었다. 또한, 이 데이터셋에는 일반 자연어 데이터도 포함되어 있는데, 이는 PaLM 사전학습에 사용된 것과 동일한 데이터셋이다.
우리의 수학 웹페이지 데이터셋은 MathJax 형식의 수학 표현식을 포함하는 페이지들을 수집하여 구축되었다. 이 페이지들은 대부분의 HTML 태그를 제거하지만, ETEX 기호 및 서식을 포함한 수학적 표기법은 보존하는 클리닝 과정을 거쳤다. 그 결과, 학습 중에 eπi+1=0e^{\pi i}+1=0 또는 E=mc2E=m c^{2}와 같은 수학 공식들이 모델에 온전히 제시된다. 이러한 절차는 모델이 계산 및 기호 조작을 요구하는 task에서 우수한 성능을 발휘할 수 있도록 한다. Table 1은 학습 데이터셋의 세부 구성을 보여준다. 더 자세한 내용은 Appendix B를 참조하라.

Table 1: 기술 학습 데이터셋 내 각 소스별 데이터 비율 및 토큰 수.
General Natural Language 데이터셋은 모델 사전학습에 사용된 데이터셋의 부분집합이다.

Data sourceProportion of dataTokensPresent during pretraining
Math Web Pages47.5%47.5 \%17.5 BNo
arXiv47.5%47.5 \%21.0 BNo
General Natural Language Data5%5 \%>100 B>100 \mathrm{~B}Yes

2.2 Models and Training Procedure

우리의 접근 방식은 PaLM 사전학습된 decoder-only Transformer language model인 Chowdhery et al. (2022)을 기반으로 시작하여, **autoregressive objective를 사용하여 수학 데이터셋으로 추가 학습(fine-tuning)**하는 것이다. Table 2에는 주요 모델 및 학습 하이퍼파라미터가 포함되어 있다. 540B 파라미터를 가진 가장 큰 모델은 26B 토큰으로 fine-tuning되었다. 이 모델은 8B 및 62B 모델에 비해 매우 적게 학습되었음에도 불구하고 여전히 우수한 성능을 달성한다. 추가 세부 정보는 Appendix C에서 확인할 수 있다.

Table 2: 모델 아키텍처 및 계속된 학습 하이퍼파라미터.
모델 학습은 사전학습된 PaLM 모델에서 재개되었으며, 인용된 Steps 수는 우리의 기술 데이터셋에 대한 계속된 학습만을 의미한다.

ModelLayersHeadsdmodel d_{\text {model }}ParametersStepsTokens
Minerva 8B321640968.63 B624 k164 B
Minerva 62B6432819262.50 B416 k109 B
Minerva 540B1184818432540.35 B399 k26 B

2.3 Evaluation Datasets

우리는 주로 few-shot 평가에 초점을 맞추지만, fine-tuned 평가에 대한 논의는 Appendix E.3을 참조하라. 평가를 위해 입력은 왼쪽에서부터 1024 토큰으로 잘라내고(truncate), 모델은 최대 512 토큰을 생성하도록 설정한다. 문제당 한 번 샘플링할 때는 greedy 샘플링을 사용한다. 문제당 여러 번 샘플링할 때는 **nucleus sampling (Holtzman et al., 2019)을 사용하며, temperature T=0.6,p=0.95T=0.6, p=0.95**로 설정한다.
생성(generative) task의 경우, 모델은 chain-of-thought 답변을 생성하고 최종 답변을 명확히 구분한다. 우리는 최종 답변이 정답과 일치하면, 그 앞에 오는 chain-of-thought의 품질과 관계없이 정답으로 평가한다. 정확성 평가를 위해 최종 답변을 파싱하고 SymPy 라이브러리 (Meurer et al., 2017)를 사용하여 비교한다. 이는 1/31 / \sqrt{3}3/3\sqrt{3} / 3처럼 수학적으로 동등한 답변을 올바르게 식별하기 위함이다. 자세한 내용은 Appendix D.1을 참조하라.

우리가 중점적으로 다루는 기존 데이터셋은 다음과 같다:

  • MATH: Hendrycks et al. (2021)이 제시한 12,000개의 중고등학교 수학 문제 데이터셋이다. 문제 설명은 ATEX\mathrm{AT}_{\mathrm{E}} \mathrm{X}로 작성되어 있다. 우리는 고정된 4-shot prompt로 모델에 입력한다 (Appendix D.2에 목록화되어 있음). 이 prompt에는 정답 길이가 너무 길지 않은 학습 데이터셋의 무작위 예시 4개가 포함되어 있다.
  • GSM8k: Cobbe et al. (2021)이 제시한 중학교 수학 word problem 데이터셋이다. 모델은 Wei et al. (2022)의 chain-of-thought prompt를 사용하여 평가된다. GSM8k에서 평가된 이전 모델들은 외부 계산기를 사용했지만, 본 연구에서 우리 모델은 어떠한 외부 도구에도 접근할 수 없다.

Figure 4: MATH 및 MMLU-STEM의 하위 주제별 성능. Minerva는 두 데이터셋 모두에서 state-of-the-art 결과를 달성한다. maj1@k는 각 문제에 대해 kk개의 샘플을 생성하고 가장 흔한 답변만 선택한 평가를 나타낸다 (Wang et al., 2022). MATH의 경우, Minerva 8B 및 62B는 k=256k=256, 540B는 k=64k=64이다. MMLU-STEM의 경우 k=16k=16이다. davinci-002는 OpenAI의 최신 공개 언어 모델이다.

  • MMLU-STEM: MMLU 데이터셋 (Hendrycks et al., 2020)의 하위 집합으로, 과학, 기술, 공학, 수학 (STEM) 분야에 초점을 맞춘다. 원본 버전의 경우, 각 task에 대해 개발 세트에서 5-shot prompt를 사용한다. 또한 이 task에 대해 chain-of-thought prompting도 고려하는데, 이때는 단계별 솔루션이 포함된 예시로 모델에 prompt를 제공한다. 수학적 추론이 필요한 주제에는 MATH prompt의 객관식 버전을 사용하고, 나머지 주제에는 표준 5-shot prompt에 단계별 솔루션을 추가한다. 자세한 내용은 Appendix G를 참조하라.

2.4 Undergraduate-Level STEM Problems

Minerva의 과학적 추론 능력을 평가하기 위해, 우리는 학부 수준의 STEM 문제들을 수집했으며, 이들 대부분은 다단계 추론을 포함한다. 본 논문에서는 이 문제들을 OCWCourses라고 부른다. 우리는 MIT에서 공개적으로 제공하는 강의 자료(OpenCourseWare)를 활용하여, "고체 화학", "정보와 엔트로피", "미분 방정식", "특수 상대성 이론" 등의 과목에서 **자동으로 검증 가능한 해답(숫자 또는 SymPy를 통한 기호 검증 가능)**을 가진 문제들을 수집했다. 이 문제들은 계약자들에 의해 **자체 완결적(self-contained)**이며 명확하게 구분된 최종 답안을 가지도록 처리되었다. 증명이나 개방형 단답형 질문을 요구하는 문제는 포함되지 않았다. 총 272개의 문제를 선별했으며, 이 중 191개는 숫자 해답을, 81개는 기호 해답을 가진다. Appendix F에서는 각 과목의 기여도와 이러한 강의 자료를 language model이 처리하기에 적합한 형식으로 변환하는 과정을 자세히 설명한다. 또한 모든 문제의 텍스트를 제공한다. 우리는 이 문제들을 향후 논문에서 자세히 다룰 오픈 소스 데이터셋의 일부로 공개할 계획이다.

2.5 Inference-Time Techniques

우리는 k>1k>1개의 해답을 샘플링(0이 아닌 temperature 사용)하고 다수결 투표(majority voting)를 통해 하나를 선택함으로써 greedy decoding보다 훨씬 뛰어난 성능을 달성할 수 있음을 발견했다 (Wang et al., 2022). 이 방법은 예측들을 최종 답변에 따라 그룹화하고 가장 흔한 답변을 선택하는 방식으로 이루어진다. 우리는 이를 Li et al. (2022)에 따라 maj1@k라고 부른다. 이 알고리즘의 변형인 majn@k는 가장 흔한 nn개의 답변을 선택하는 방식이다.

Table 3: 여러 정량적 추론 데이터셋에 대한 모델 성능. 다수결 투표(majority voting)의 경우, MATH에는 k=256k=256 (540B 모델의 경우 64)개의 샘플을, OCWCourses에는 k=64k=64개의 샘플을, GSM8k에는 k=100k=100 (540B 모델의 경우 40)개의 샘플을, MMLU-STEM에는 k=16k=16개의 샘플을 사용했다. PaLM GSM8k 결과는 계산기를 사용하지 않았으며 (Chowdhery et al., 2022)에서 보고되었다. 우리는 최근 모델에 대한 공개된 결과가 없는 데이터셋을 OpenAI davinci-002에서 평가했다. MMLU-STEM이 객관식(multiple choice) task임에도 불구하고, 모델이 최종 답변 전에 추론 과정(rationale)을 생성하도록 prompt를 구성하고, 여러 번 샘플링한 다음, 최종 답변에 대해 다수결 투표를 적용하여 majority vote를 사용할 수 있다. 위첨자는 이전 연구에서 인용된 결과를 나타낸다: a{ }^{a} GPT-2 Hendrycks et al. (2021), b{ }^{b} PaLM 540B maj1@40 Wang et al. (2022), c{ }^{c} Chinchilla Hoffmann et al. (2022).

MATHOCWCoursesGSM8kMMLU-STEM
PaLM 8B1.5%1.5%4.1%22.0%
Minerva 8B14.1%7.7%16.2%35.6%
Minerva 8B, maj1@k25.4%12.5%28.4%43.4%
PaLM 62B4.4%5.9%33.0%39.1%
Minerva 62B27.6%12.9%52.4%53.9%
Minerva 62B, maj1@k43.4%23.5%68.5%63.5%
PaLM 540B8.8%7.1%56.5%58.7%
Minerva 540B33.6%17.6%58.8%63.9%
Minerva 540B, maj1@k50.3%30.8%78.5%75.0%
OpenAI davinci-00219.1%14.8%--
Published SOTA6.9%a6.9 \%^{a}-74.4%b74.4 \%^{b}54.9%c54.9 \%^{c}

직관적으로, 다수결 투표가 성능을 향상시키는 이유는 질문에 잘못 답하는 방법은 많지만, 올바르게 답하는 방법은 일반적으로 매우 적기 때문이다.

다수결 투표를 pass@k와 비교해보자. pass@k는 kk개의 샘플 중 단 하나라도 task를 해결하면 성공으로 간주하는 방식이다. pass@k 성능에 대한 자세한 내용은 Section 4.2를 참조하라. Appendix E.1에서는 다양한 metric에 대해 성능이 kk에 어떻게 의존하는지 보고한다. 우리는 pass@k는 kk가 증가함에 따라 계속해서 성능이 향상되는 반면, 다수결 투표 성능은 더 빠르게 포화됨을 발견했다. MATH의 경우 k=64k=64에서, GSM8k의 경우 k=16k=16에서 큰 kk 정확도의 97%가 달성된다. 이는 다수결 투표가 모델링된 분포에서 가장 흔한 답변을 선택하고, 이 추정치의 오차가 kk가 증가함에 따라 감소하기 때문일 가능성이 높다. 이는 분포의 꼬리 부분에서 성능 향상이 발생하는 pass@k와는 대조적이며, pass@k는 kk가 증가함에 따라 계속해서 향상될 수 있다.

Log-likelihood는 샘플을 재순위화(rerank)하는 데 사용될 수 있는 또 다른 metric이다. 우리는 다수결 투표가 log-likelihood 재순위화보다 훨씬 더 나은 성능을 보인다는 것을 발견했다 (Appendix E.2 참조).

3 Results

Table 3는 Section 2.3에서 설명된 평가 데이터셋에 대한 Minerva 모델 및 다른 모델들의 결과를 요약한다. Figure 4는 MATH 데이터셋 결과를 하위 주제별로 분류하여 보여준다. MMLU 평가의 경우, 별도로 명시되지 않는 한, 각 주제별 표준 5-shot prompt를 사용하고 가장 높은 점수를 받은 답변을 선택하여 성능을 측정한다. MMLU를 다수결 투표(majority voting) 방식으로 평가할 때는 chain-of-thought prompt를 사용하여 k=16k=16개의 모델 답변을 샘플링한다.

모델 출력 샘플은 Figure 1, 2, 3에 제시되어 있으며, 추가 출력 샘플은 Appendix에 수록되어 있다. 또한, 우리는 Minerva 62B를 폴란드 전국 수학 시험(National Math Exam in Poland)에서 평가했으며, 그 결과 57%의 점수를 달성했다. 이 점수는 2021년 전국 평균 점수와 일치한다 (CKE, 2021, p. 23). 540B 모델은 65%를 달성했다.

우리는 OpenAI의 최신 공개 언어 모델인 davinci-002에 대한 결과도 포함했으며, 이는 OpenAI API를 사용하여 공식 권장 온도(T=0.2)로 설정하여 평가되었다. 학습 데이터, 규모, 추론 기법의 조합은 우리가 고려한 모든 기술 task에서 state-of-the-art 결과를 도출한다. 모든 task(GSM8k 제외)에서 이전 결과 대비 상당한 개선이 이루어졌다.

우리의 주요 초점은 few-shot 평가에 있지만, 우리는 Minerva를 MATH 데이터셋에 대해 fine-tuning하는 시도도 했다. 이 시도에서는 어떠한 개선도 관찰되지 않았지만, PaLM을 MATH 데이터셋에 fine-tuning했을 때는 상당한 개선이 있었다. 이는 비지도 학습 데이터셋의 품질과 다양성이 향상될수록 표준 fine-tuning의 한계 효용이 감소한다는 것을 시사한다. 자세한 내용은 Appendix E.3에서 확인할 수 있다.

3.1 Basic arithmetic

Appendix H에서는 Minerva 540B단순 산술 task 성능을 연구한다. 이 모델은 10자리 덧셈에서 80% 이상의 정확도를 달성했으며, 18자리 덧셈에서는 20% 이상의 정확도를 기록했다.

4 Performance Analysis

4.1 Model Mistakes

모델이 저지르는 오류 유형을 더 잘 이해하기 위해, 우리는 Minerva 8B와 Minerva 62B가 **높은 확신도를 가진 다수결 결정(majority decisions)**을 내린 216개의 문제에 대한 성능을 비교했다. 구체적으로, 우리는 최고 답변이 최소 15%의 득표율을 얻은 예시들을 선택했으며, 이 중 **Minerva 8B는 정답이고 Minerva 62B는 오답인 경우(15개 샘플)**와 **그 반대인 경우(201개 샘플)**를 분석했다. 각 카테고리별 설명과 예시는 Appendix I.2에 제시되어 있다.

Table 4에서 볼 수 있듯이, 8B 모델의 주요 오류는 잘못된 추론(reasoning) 또는 계산과 관련이 있었다. 계산 오류 중 상당수는 비교적 경미한 산술적 실수였다. **해결책이 너무 짧은 경우(모델이 중간 추론 단계 없이 즉시 오답을 생성하는 경우)**는 상대적으로 드물었다. 마지막으로, 몇몇 경우에는 모델이 실제하지 않는 방정식이나 수학적 사실을 환각(hallucinate)하는 경우도 있었다.

62B 모델이 오답을 낸 샘플들에서는, 주요 실패 모드 역시 잘못된 추론과 잘못된 계산이었다. 요약하자면, 우리는 62B Minerva 모델이 8B 모델의 대부분의 기술을 유지하면서도, 추론 및 계산의 견고성(robustness)을 모두 향상시켰음을 발견했다.

Table 4: 62B 모델은 정답을 맞혔지만 8B 모델은 틀린 201개 샘플 중 8B Minerva 모델의 실패 모드.

Type of mistakesOccurrences
Incorrect reasoning82
Incorrect calculation70
Misunderstands question22
Uses incorrect fact16
Solution too short4
Hallucinated math objects4
Other mistakes3

4.2 False Positives

정량적 추론 문제 해결을 위한 우리의 접근 방식에서, 우리는 문제의 최종 답안이 올바른지 여부를 자동으로 검증할 수 있지만, 모델의 추론 과정(chain of reasoning)을 자동으로 검증하는 방법은 없다. 이로 인해 false positive의 가능성이 열려 있다. 즉, 최종 답안은 올바르지만 추론 과정이 불완전하거나 잘못된 샘플들이 존재할 수 있다.
우리는 **MATH 데이터셋에서 무작위로 100개의 질문(난이도별 20개)**을 선택하고, 62B 모델에서 zero temperature로 샘플링된 답안을 함께 추출했다. 그런 다음, false positive rate를 확인하기 위해 답안을 수동으로 검사했다. false positive rate최종 답안이 올바른 예시 중 false positive 예시의 비율을 의미한다 (Table 5 참조). 우리는 전반적인 false positive rate가 낮음을 발견했지만, 난이도 수준이 높아질수록 증가하는 경향을 보였다.

우리가 pass@1과 majority voting을 주요 평가 지표로 삼는 이유는 부분적으로 이들이 pass@k보다 false positive에 덜 민감하기 때문이다 (Li et al., 2022). 62B 모델의 **pass@256 정확도는 84.5%**이지만, 이 중 일부는 false positive에 해당한다. 우리는 majority voting에서는 실패했지만, 단 하나의 올바른 답안으로 인해 pass@k에서는 통과한 샘플들을 검사했으며, 이러한 방식으로 선택된 샘플들 중 **pass@256의 false positive rate는 30%**로 추정된다. false positive를 제거한 후, pass@256 정확도는 68%보다 클 것으로 추정된다 (자세한 내용은 Appendix I.3 참조).

Table 5: MATH 데이터셋에서 62B 모델의 난이도별 false positive rate 추정치.
평균은 MATH 데이터셋에 대한 추정 false positive rate이며, 난이도별 false positive rate를 positive rate로 가중 평균하여 얻어진다.

Difficulty level
12345Average
False positive rate<5%<5 \%10%10 \%<5%<5 \%15%15 \%30%30 \%8%8 \%

5 Memorization

Minerva의 솔루션을 해석하는 데 있어 핵심적인 질문은 성능이 진정한 분석 능력을 반영하는지, 아니면 단순히 암기(rote memorization)의 결과인지이다. 이는 언어 모델이 학습 데이터의 일부를 암기하는 경우가 많다는 선행 연구들(Trinh and Le, 2018; Radford et al., 2019; Carlini et al., 2022)이 많기 때문에 특히 중요하다. 모델 솔루션을 검토해보면, 제곱근의 수치나 삼각 함수 항등식과 같은 중간 사실(intermediate facts)의 암기가 모델 솔루션의 중요한 요소임을 알 수 있다. 진정으로 강력한 성능은 중간 사실의 회상과 진정한 솔루션 합성 능력을 결합하는 것이다. 우리는 모델 성능이 평가 세트에 있는 명시적인 문제와 솔루션을 암기한 결과인 강한 형태의 암기뿐만 아니라, 동일한 질문에 대한 다른 답변을 암기한 약한 형태의 암기도 조사하고자 한다.

모델이 학습 데이터에서 암기한 정보를 회상하여 문제를 해결하는 정도를 평가하기 위해, 우리는 MATH 데이터셋에 대해 세 가지 분석을 수행한다. 첫째, 우리는 학습 코퍼스에서 문제와 솔루션을 직접 검색한다. 다음으로, 문제의 수정된 버전을 생성하고 이러한 변경에 대한 모델의 견고성을 평가한다. 마지막으로, ground truth 솔루션과 모델이 생성한 솔루션 간의 중복 정도를 측정하고 이 유사성이 모델 성능에 미치는 영향을 측정한다. 전반적으로, 우리는 모델의 성능이 암기에 기인한다는 증거를 거의 찾지 못했다.

5.1 Training and Evaluation Dataset Overlap

우리는 62B 파라미터 모델이 정답을 맞춘 문제들을 선별하고, 그중 가장 높은 다수결 투표 점수(majority vote score)를 받은 100개의 문제를 필터링했다. 이는 다수결 투표 점수가 높은 문제가 모델에 의해 암기되었을 가능성이 더 높을 것이라는 가정에 기반한다.
각 질문-답변 쌍에 대해, 우리는 Math Web Pages 데이터셋에서 500자 단위(chunks of 500 characters)로 BLEU 점수를 계산했다 (BLEU 점수의 히스토그램은 Appendix Figure 10에 제시되어 있다).
이후, 가장 높은 BLEU 점수를 받은 250개의 문서를 수동으로 검토했다.
상위 일치 항목 중 상당수는 수학 문제와 풀이가 있는 숙제 도움 사이트에서 나왔지만, 검토 중인 MATH의 하위 집합에 있는 문제와 일치하는 질문은 없었다. 우리는 이 250개의 세그먼트를 Appendix J.1에 포함했다.
MATH의 일부 문제는 웹에서 찾을 수 있다는 점을 인지하고 있다. 그럼에도 불구하고, 이 분석은 해당 문제들이 우리의 데이터 수집 과정을 통해 모델에 포함되지 않았음을 결론짓는다.

5.2 Performance on Modified MATH Problems

메모리화(memorization)를 추가적으로 조사하기 위해, 우리는 62B 모델이 다수결 투표(majority voting) 하에 올바르게 답변한 문제 20개를 무작위로 선정했다. 우리는 각 문제를 문제 문구에 사소한 변경(framing)을 가하거나, 문제에 나타난 숫자를 변경하고 그에 따라 해답을 수정하는 방식으로 수동으로 수정했다. 그런 다음, 수정 전후의 샘플링된 해답에 대한 정확도를 비교했다. 결과는 Figure 5에 나와 있다. 두 경우 모두 수정 전후의 정확도가 서로 상관관계를 보였으며, 원래 문제 구성에 유리한 명확한 편향은 없었다. 이는 메모리화가 최소한으로 이루어졌음을 시사한다. 수정된 문제들은 Appendix J.2에 나열되어 있다.

Figure 5: MATH 데이터셋에서 메모리화 부족을 나타내는 결과. 왼쪽, 중앙: MATH 데이터셋의 원본 질문과 수정된 버전의 정확도. 각 점은 하나의 질문을 나타낸다. xx축은 원본 질문의 정확도이고, yy축은 수정된 질문의 정확도이다. 오른쪽: ground truth 해답과의 BLEU 점수가 xx축 값보다 작거나 같은 샘플에 대해서만 계산된 다수결 투표 정확도.

5.3 BLEU Score Between Ground Truth and Generated Solutions

우리는 정답(ground truth)과 모델이 생성한 답안 사이의 BLEU 점수를 계산하여 모델의 해답 암기(memorization) 여부를 탐지하고자 한다. 이를 위해 62B 모델을 사용하고, MATH 데이터셋의 각 문제당 256개의 샘플을 분석한다.
먼저, 모든 정답 샘플에 대한 overlap 통계를 계산한다. 그 결과, 5,000개의 테스트 문제 중 160개 문제에서 BLEU 점수가 80 이상인 샘플이 존재함을 확인했다 (Appendix J.3 참조). 이러한 샘플들은 대체로 짧은 해답인 경향이 있었다.
답안 유사도가 성능에 미치는 영향을 이해하기 위해, 우리는 특정 BLEU 점수 임계값 이상의 모델 샘플을 제거하고, majority vote 정확도를 재계산했다.
그 결과, majority vote 성능은 상대적으로 낮은 유사도에서도 강건함을 발견했다 (Figure 5 참조). 이는 모델의 성능이 정답과 매우 유사한 출력에 기인한다고 볼 수 없음을 시사한다.

6 Conclusions and Discussion

본 연구에서는 자연어로 표현된 수학적 추론을 사용하여 문제를 해결하는 방식에 기반한 정량적 추론(quantitative reasoning) 접근법을 취한다. 우리는 고품질 수학 데이터셋으로 대규모 language model을 학습시킴으로써, 논리적 추론, 수치 계산, 기호 조작을 요구하는 task에서 강력한 성능을 달성할 수 있음을 보여준다.

우리의 모델은 외부 도구를 사용하지 않으며, 추론 시에는 오직 autoregressive sampling에만 의존하여 이러한 성능을 달성한다. 정량적 추론에 대한 보완적인 접근법으로는 **코드 생성 모델(code-generating models)**과 **형식적 방법(formal methods)**이 있다. 이들은 모두 정량적 문제를 추론하고 해결할 수 있는 에이전트라는 공통된 목표를 향한 다른 경로들이다. 우리는 이러한 에이전트가 이 모든 접근법의 유용한 요소들을 결합해야 한다고 믿는다.

6.1 Limitations of Our Approach

우리의 정량적 추론(quantitative reasoning) 접근 방식에는 몇 가지 한계점이 있다.
첫째, 우리는 모델 답변의 정확성을 자동으로 검증할 방법이 없다. 이는 자동 검증이 본질적으로 가능한 formal한 접근 방식과는 대조적이다.
둘째, 우리 모델은 계산기나 Python interpreter와 같은 외부 도구에 접근할 수 없다. 따라서 복잡한 수치 계산이 필요한 정량적 추론 task를 수행하는 능력에 제한이 있다.
셋째, 우리 모델은 대량의 데이터로 학습되었기 때문에, 모델이 습득한 특정 능력에 대해 직접적인 제어력이 거의 없다.

6.2 Societal Impact

일반적인 환경에서 정량적 추론 문제를 해결할 수 있는 인공신경망은 사회에 상당한 영향을 미칠 잠재력을 가지고 있다. Minerva는 이러한 방향으로 나아가는 한 걸음이지만, 아직 이 목표를 달성하기에는 멀었으며, 따라서 사회적 영향은 제한적이다. 모델의 성능은 여전히 인간의 성능에 훨씬 못 미치며, 더욱이 출력의 정확성을 자동으로 검증할 방법이 없다. 이러한 문제들이 해결된다면, 이 모델의 영향은 전반적으로 긍정적일 것으로 예상된다. 직접적인 응용 분야로는 접근 가능하고 저렴한 수학 튜터가 될 수 있으며, 이는 교육 불평등을 개선하는 데 도움이 될 수 있다.