364 lines
7.2 KiB
Go
364 lines
7.2 KiB
Go
package intcodeprocessor
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
)
|
|
|
|
// Opcodes understood by Program.Step. The opcode is the two low decimal
// digits of an instruction word (see opCode).
const (
	OP_ADD = 1  // add two parameters, store the sum
	OP_MLT = 2  // multiply two parameters, store the product
	OP_INP = 3  // receive a value from the input channel and store it
	OP_OUT = 4  // send a parameter value on the output channel
	OP_JIT = 5  // jump if the first parameter is non-zero
	OP_JIF = 6  // jump if the first parameter is zero
	OP_ILT = 7  // store 1 if first < second, otherwise 0
	OP_IEQ = 8  // store 1 if first == second, otherwise 0
	OP_RBS = 9  // adjust the relative base by a parameter
	OP_EXT = 99 // stop execution
)
|
|
|
|
// Parameter modes, as decoded by paramMode.
const (
	MODE_POS = 0 // the parameter is an address into program memory
	MODE_IMM = 1 // the parameter is the value itself
	MODE_REL = 2 // the parameter is an address offset from the relative base
)
|
|
|
|
// Status codes returned by Program.Step and Program.Run.
const (
	RET_ERR  = -1 // execution failed; see Program.Error
	RET_OK   = 0  // one instruction executed, the program can continue
	RET_DONE = 1  // the program executed OP_EXT
)
|
|
|
|
// Program is an Intcode virtual machine: a program image, the working
// memory derived from it, and the execution state around that memory.
type Program struct {
	// originalCode is the pristine copy of the program; Reset restores
	// the working memory from it.
	originalCode []int
	// code is the working memory that instructions read and write.
	code []int
	// ptr is the index in code of the next value to be read.
	ptr int
	// relBase is the offset added to MODE_REL parameters.
	relBase int

	// state holds the most recent status recorded by Run (RET_* values).
	state int
	// error holds the failure that caused a RET_ERR status, if any.
	error error

	// waitingForInput is set while an input instruction is receiving
	// from the input channel.
	waitingForInput bool
	// input carries values supplied through Input.
	input chan int
	// waitingForOutput is set once an output instruction is about to
	// send, and cleared by Output.
	waitingForOutput bool
	// output carries values produced by output instructions.
	output chan int

	// debug enables trace printing through DebugLog.
	debug bool
}
|
|
|
|
func NewProgram(prog []int) *Program {
|
|
p := new(Program)
|
|
p.originalCode = make([]int, len(prog))
|
|
p.code = make([]int, len(prog))
|
|
copy(p.originalCode, prog)
|
|
p.Reset()
|
|
return p
|
|
}
|
|
|
|
func (p *Program) EnableDebug() {
|
|
p.debug = true
|
|
}
|
|
|
|
func (p *Program) DisableDebug() {
|
|
p.debug = false
|
|
}
|
|
|
|
func (p *Program) DebugLog(l string) {
|
|
if p.debug {
|
|
fmt.Print(l)
|
|
}
|
|
}
|
|
|
|
func (p *Program) Reset() {
|
|
copy(p.code, p.originalCode)
|
|
p.ptr = 0
|
|
p.state = RET_OK
|
|
p.error = nil
|
|
p.waitingForInput = false
|
|
p.waitingForOutput = false
|
|
p.input = make(chan int)
|
|
p.output = make(chan int)
|
|
p.relBase = 0
|
|
}
|
|
|
|
func (p *Program) GetCode() []int {
|
|
return p.code
|
|
}
|
|
|
|
func (p *Program) State() int {
|
|
return p.state
|
|
}
|
|
|
|
func (p *Program) Error() error {
|
|
return p.error
|
|
}
|
|
|
|
func (p *Program) Run() int {
|
|
for {
|
|
p.state = p.Step()
|
|
if p.state != RET_OK {
|
|
return p.state
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Program) Step() int {
|
|
if len(p.code) < p.ptr {
|
|
p.error = errors.New("Pointer Exception")
|
|
return RET_ERR
|
|
}
|
|
p.DebugLog(p.String() + "\n")
|
|
intcode := p.readNext()
|
|
p.ptr++
|
|
switch p.opCode(intcode) {
|
|
case OP_ADD:
|
|
v1, v2, v3 := p.readNextThree()
|
|
p.DebugLog(fmt.Sprintf("ADD %d (%d, %d, %d)\n", intcode, v1, v2, v3))
|
|
p.ptr = p.ptr + 3
|
|
p.opAdd(intcode, v1, v2, v3)
|
|
if p.error != nil {
|
|
return RET_ERR
|
|
}
|
|
return RET_OK
|
|
case OP_MLT:
|
|
v1, v2, v3 := p.readNextThree()
|
|
p.DebugLog(fmt.Sprintf("MLT %d (%d, %d, %d)\n", intcode, v1, v2, v3))
|
|
p.ptr = p.ptr + 3
|
|
p.opMult(intcode, v1, v2, v3)
|
|
if p.error != nil {
|
|
return RET_ERR
|
|
}
|
|
return RET_OK
|
|
case OP_INP:
|
|
v1 := p.readNext()
|
|
p.DebugLog(fmt.Sprintf("INP %d (%d)\n", intcode, v1))
|
|
p.ptr++
|
|
p.opInp(intcode, v1)
|
|
if p.error != nil {
|
|
return RET_ERR
|
|
}
|
|
return RET_OK
|
|
case OP_OUT:
|
|
v1 := p.readNext()
|
|
p.ptr++
|
|
p.DebugLog(fmt.Sprintf("OUT %d (%d)\n", intcode, v1))
|
|
p.opOut(intcode, v1)
|
|
if p.error != nil {
|
|
return RET_ERR
|
|
}
|
|
return RET_OK
|
|
case OP_JIT:
|
|
v1, v2 := p.readNextTwo()
|
|
p.DebugLog(fmt.Sprintf("JIT %d (%d, %d)\n", intcode, v1, v2))
|
|
p.ptr = p.ptr + 2
|
|
p.opJumpIfTrue(intcode, v1, v2)
|
|
if p.error != nil {
|
|
return RET_ERR
|
|
}
|
|
return RET_OK
|
|
case OP_JIF:
|
|
v1, v2 := p.readNextTwo()
|
|
p.DebugLog(fmt.Sprintf("JIF %d (%d, %d)\n", intcode, v1, v2))
|
|
p.ptr = p.ptr + 2
|
|
p.opJumpIfFalse(intcode, v1, v2)
|
|
if p.error != nil {
|
|
return RET_ERR
|
|
}
|
|
return RET_OK
|
|
case OP_ILT:
|
|
v1, v2, dest := p.readNextThree()
|
|
p.DebugLog(fmt.Sprintf("ILT %d (%d, %d, %d)\n", intcode, v1, v2, dest))
|
|
p.ptr = p.ptr + 3
|
|
p.opIfLessThan(intcode, v1, v2, dest)
|
|
if p.error != nil {
|
|
return RET_ERR
|
|
}
|
|
return RET_OK
|
|
case OP_IEQ:
|
|
v1, v2, dest := p.readNextThree()
|
|
p.DebugLog(fmt.Sprintf("IEQ %d (%d, %d, %d)\n", intcode, v1, v2, dest))
|
|
p.ptr = p.ptr + 3
|
|
p.opIfEqual(intcode, v1, v2, dest)
|
|
if p.error != nil {
|
|
return RET_ERR
|
|
}
|
|
return RET_OK
|
|
case OP_RBS:
|
|
v1 := p.readNext()
|
|
p.DebugLog(fmt.Sprintf("RBS %d (%d)\n", intcode, v1))
|
|
p.ptr = p.ptr + 1
|
|
p.opModRelBase(intcode, v1)
|
|
if p.error != nil {
|
|
return RET_ERR
|
|
}
|
|
return RET_OK
|
|
|
|
case OP_EXT:
|
|
p.DebugLog(fmt.Sprintf("EXT %d\n", intcode))
|
|
return RET_DONE
|
|
}
|
|
p.error = errors.New(fmt.Sprintf("Invalid OpCode (%d)", intcode))
|
|
p.DebugLog(p.String())
|
|
return RET_ERR
|
|
}
|
|
|
|
func (p *Program) GetCurrentOpCode() int {
|
|
return p.code[p.ptr]
|
|
}
|
|
|
|
func (p *Program) GetProgramValueAt(idx int) int {
|
|
p.ensureLength(idx)
|
|
return p.code[idx]
|
|
}
|
|
|
|
func (p *Program) SetProgramValueAt(idx, val int) {
|
|
p.ensureLength(idx)
|
|
p.code[idx] = val
|
|
}
|
|
|
|
func (p *Program) NeedsInput() bool {
|
|
return p.waitingForInput
|
|
}
|
|
|
|
func (p *Program) NeedsOutput() bool {
|
|
return p.waitingForOutput
|
|
}
|
|
|
|
func (p *Program) opCode(intcode int) int {
|
|
return intcode % 100
|
|
}
|
|
|
|
func (p *Program) paramMode(intcode, pNum int) int {
|
|
plc := math.Pow10(pNum + 2)
|
|
return ((intcode - p.opCode(intcode)) / int(plc)) % 10
|
|
}
|
|
|
|
func (p *Program) ensureLength(idx int) {
|
|
for len(p.code) < idx+1 {
|
|
p.code = append(p.code, 0)
|
|
}
|
|
}
|
|
|
|
func (p *Program) readNext() int {
|
|
p.ensureLength(p.ptr)
|
|
return p.code[p.ptr]
|
|
}
|
|
|
|
func (p *Program) readNextTwo() (int, int) {
|
|
p.ensureLength(p.ptr + 1)
|
|
return p.code[p.ptr], p.code[p.ptr+1]
|
|
}
|
|
|
|
func (p *Program) readNextThree() (int, int, int) {
|
|
p.ensureLength(p.ptr + 2)
|
|
return p.code[p.ptr], p.code[p.ptr+1], p.code[p.ptr+2]
|
|
}
|
|
|
|
func (p *Program) get(mode, v int) int {
|
|
if mode == MODE_POS {
|
|
p.ensureLength(v)
|
|
return p.code[v]
|
|
} else if mode == MODE_REL {
|
|
p.ensureLength(p.relBase + v)
|
|
return p.code[p.relBase+v]
|
|
}
|
|
return v
|
|
}
|
|
|
|
func (p *Program) set(mode, idx, v int) {
|
|
if mode == MODE_POS {
|
|
p.ensureLength(idx)
|
|
p.code[idx] = v
|
|
} else if mode == MODE_REL {
|
|
p.ensureLength(p.relBase + idx)
|
|
p.code[p.relBase+idx] = v
|
|
}
|
|
}
|
|
|
|
func (p *Program) Input(v int) {
|
|
p.input <- v
|
|
p.waitingForInput = false
|
|
}
|
|
|
|
func (p *Program) Output() int {
|
|
v := <-p.output
|
|
p.waitingForOutput = false
|
|
return v
|
|
}
|
|
|
|
func (p *Program) opAdd(intcode, a1, a2, dest int) {
|
|
a1md, a2md, destmd := p.paramMode(intcode, 0), p.paramMode(intcode, 1), p.paramMode(intcode, 2)
|
|
p.set(destmd, dest, p.get(a1md, a1)+p.get(a2md, a2))
|
|
}
|
|
|
|
func (p *Program) opMult(intcode, a1, a2, dest int) {
|
|
a1md, a2md, destmd := p.paramMode(intcode, 0), p.paramMode(intcode, 1), p.paramMode(intcode, 2)
|
|
p.set(destmd, dest, p.get(a1md, a1)*p.get(a2md, a2))
|
|
}
|
|
|
|
func (p *Program) opInp(intcode, dest int) {
|
|
destmd := p.paramMode(intcode, 0)
|
|
p.waitingForInput = true
|
|
p.set(destmd, dest, <-p.input)
|
|
p.waitingForInput = false
|
|
}
|
|
|
|
func (p *Program) opOut(intcode, val int) {
|
|
valmd := p.paramMode(intcode, 0)
|
|
ret := p.get(valmd, val)
|
|
p.waitingForOutput = true
|
|
p.output <- ret
|
|
}
|
|
|
|
func (p *Program) opJumpIfTrue(intcode, v1, v2 int) {
|
|
v1md, v2md := p.paramMode(intcode, 0), p.paramMode(intcode, 1)
|
|
if p.get(v1md, v1) != 0 {
|
|
p.ptr = p.get(v2md, v2)
|
|
}
|
|
}
|
|
|
|
func (p *Program) opJumpIfFalse(intcode, v1, v2 int) {
|
|
v1md, v2md := p.paramMode(intcode, 0), p.paramMode(intcode, 1)
|
|
if p.get(v1md, v1) == 0 {
|
|
p.ptr = p.get(v2md, v2)
|
|
}
|
|
}
|
|
|
|
func (p *Program) opIfLessThan(intcode, v1, v2, dest int) {
|
|
v1md, v2md, destmd := p.paramMode(intcode, 0), p.paramMode(intcode, 1), p.paramMode(intcode, 2)
|
|
if p.get(v1md, v1) < p.get(v2md, v2) {
|
|
p.set(destmd, dest, 1)
|
|
} else {
|
|
p.set(destmd, dest, 0)
|
|
}
|
|
}
|
|
|
|
func (p *Program) opIfEqual(intcode, v1, v2, dest int) {
|
|
v1md, v2md, destmd := p.paramMode(intcode, 0), p.paramMode(intcode, 1), p.paramMode(intcode, 2)
|
|
if p.get(v1md, v1) == p.get(v2md, v2) {
|
|
p.set(destmd, dest, 1)
|
|
} else {
|
|
p.set(destmd, dest, 0)
|
|
}
|
|
}
|
|
|
|
func (p *Program) opModRelBase(intcode, v1 int) {
|
|
v1md := p.paramMode(intcode, 0)
|
|
p.relBase = p.relBase + p.get(v1md, v1)
|
|
}
|
|
|
|
func (p Program) String() string {
|
|
var ret string
|
|
ret = ret + fmt.Sprintf("(PTR: %d, RBS: %d)\n", p.ptr, p.relBase)
|
|
for k := range p.code {
|
|
if k == p.ptr {
|
|
ret = fmt.Sprintf("%s [%d]", ret, p.code[k])
|
|
} else {
|
|
ret = fmt.Sprintf("%s %d", ret, p.code[k])
|
|
}
|
|
}
|
|
return ret
|
|
}
|