#!/usr/bin/python3

# Signetics 8X300, 8X305 disassembler
# Copyright 2016 Eric Smith <spacewar@gmail.com>

# This program is free software: you can redistribute it and/or modify
# it under the terms of version 3 of the GNU General Public License
# as published by the Free Software Foundation.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import sys
import math
from enum import Enum, IntEnum

#from s8x30x import S8X30x, CpuType
#from intelhex import IntelHex
#from memory import Memory
#from wd1000 import WD1000
#from wd1001 import WD1001

# This began as s8x30x, the 8x300 disassembler written by Eric Smith
# It was cloned from https://github.com/brouhaha/s8x30x
# Tom Trebisky shortly began mercilessly hacking on it,
# as follows:
#
# -- I tossed the IntelHex support, which I don't need.
# -- I merged all the python files into this one.
#    This lets me place a single file into /usr/local/bin or whatever
#    a step backwards in the name of convenience.
# -- I tossed the wd1000 and wd1001 files



# These were memory.py

class LengthMismatch(Exception):
    """Raised by Memory when its input streams are not all the same length."""

# in my case, we have two input files (msb, lsb)
# so we will end up with two byte strings, each 2048 long
# __len__ and __getitem__ make this memory object look like an
# array, so it responds to len() and []
# len(fw) returns 2048 in my case.
# fw[addr] returns [msb,lsb], a list of two items (one byte per input file, in file order)

class Memory:
    """Array-like view over several equally sized byte streams.

    Every stream in *args* is read to the end when the object is built.
    len() gives the common stream length and indexing with an address
    returns a list holding the byte at that address from each stream,
    in the order the streams were supplied.

    Raises LengthMismatch if the streams do not all have the same length.
    """

    def __init__(self, args):
        planes = []
        for stream in args:
            planes.append(stream.read())
        self.data = planes

        sizes = [len(plane) for plane in planes]
        # Comparing the list with itself shifted by one detects any
        # neighbouring pair of unequal lengths.
        if sizes[:-1] != sizes[1:]:
            raise LengthMismatch()
        self.size = sizes[0]

    def __len__(self):
        return self.size

    def __getitem__(self, address):
        return [plane[address] for plane in self.data]

# end of memory.py

# ======== This was s8x30x.py

# Processor variants the disassembler can be asked to target.
CpuType = Enum('CpuType', ['s8x300', 's8x305'])

class UnknownMnemonic(Exception):
    """Error whose message names a mnemonic that was not recognised."""

    def __init__(self, mnem):
        message = 'unknown mnemonic "%s"' % mnem
        Exception.__init__(self, message)

class NoMatchingForm(Exception):
    """Error carrying the fixed message 'no matching form'."""

    def __init__(self):
        message = 'no matching form'
        Exception.__init__(self, message)

class OperandOutOfRange(Exception):
    """Error carrying the fixed message 'operand out of range'."""

    def __init__(self):
        message = 'operand out of range'
        Exception.__init__(self, message)

class BadInstruction(Exception):
    """Error for an instruction that matches no known form.

    The message shows *instr* as at least four hexadecimal digits.
    """

    def __init__(self, instr):
        message = 'bad instruction %04x' % instr
        Exception.__init__(self, message)

class InternalError(Exception):
    """Signals an internal consistency failure, e.g. a non-IV register
    turning up where an IV register is required."""


class Reg(IntEnum):
    """Encodings of the 5-bit source/destination operand field (in octal).

    Values 0o00-0o17 are registers; 0o20-0o27 are liv0..liv7 and
    0o30-0o37 are riv0..riv7 (the IV operands).  Several registers have
    two names; the first name listed for a value is the one reported by
    .name.
    """
    aux  = 0o00;  r0  = 0o00
    r1   = 0o01
    r2   = 0o02
    r3   = 0o03
    r4   = 0o04
    r5   = 0o05
    r6   = 0o06
    ivl  = 0o07;  r7  = 0o07  # left bank register
    ovf  = 0o10;  r10 = 0o10
    r11  = 0o11
    r12  = 0o12               # 8x305 only, can't use for normal XMIT
    r13  = 0o13               # 8x305 only, can't use for normal XMIT
    r14  = 0o14               # 8x305 only
    r15  = 0o15               # 8x305 only
    r16  = 0o16               # 8x305 only
    # r17 must alias ivr (0o17); it was previously 0o10, which made it an
    # alias of ovf instead.
    ivr  = 0o17;  r17 = 0o17  # right bank register
    liv0 = 0o20
    liv1 = 0o21
    liv2 = 0o22
    liv3 = 0o23
    liv4 = 0o24
    liv5 = 0o25
    liv6 = 0o26
    liv7 = 0o27
    riv0 = 0o30
    riv1 = 0o31
    riv2 = 0o32
    riv3 = 0o33
    riv4 = 0o34
    riv5 = 0o35
    riv6 = 0o36
    riv7 = 0o37

    def is_iv(self, cpu_type = CpuType.s8x300):
        """Return True for the liv*/riv* operands (value >= liv0).

        cpu_type is accepted for symmetry with the other predicates and
        is not consulted.
        """
        return self >= self.__class__.liv0

    def rightmost_liv_bit(self, cpu_type = CpuType.s8x300):
        """Return 7 minus the low three bits of the encoding.

        liv0/riv0 give 7 and liv7/riv7 give 0.

        Raises InternalError if this operand is not an IV operand.
        """
        if not self.is_iv(cpu_type):
            # Was assert(InternalError(...)), which can never fail because
            # an exception instance is always truthy.
            raise InternalError('register %s is not IV' % self.name)
        return 7 - (self.value & 7)

    def is_src_reg(self, cpu_type = CpuType.s8x300):
        """Return True if this register may be used as a source register.

        aux/r0..r6, ovf and r11 are always accepted; every other non-IV
        register is accepted only when cpu_type is CpuType.s8x305.
        IV operands are never accepted.
        """
        c = self.__class__
        if self.is_iv(cpu_type):
            return False
        if self < c.ivl or self is c.ovf or self is c.r11:
            return True
        return cpu_type is CpuType.s8x305

    def is_dest_reg(self, cpu_type = CpuType.s8x300):
        """Return True if this register may be used as a destination register.

        ovf and the IV operands are never accepted.  Registers up to r11
        and ivr are always accepted; the remaining ones (r12..r16) only
        when cpu_type is CpuType.s8x305.
        """
        c = self.__class__
        if self.is_iv(cpu_type):
            return False
        if self is c.ovf:
            return False
        if self <= c.r11 or self is c.ivr:
            return True
        return cpu_type is CpuType.s8x305
        

# operand type
# defined outside the S8X30x class and with very short name because
# it will be used a lot in the __inst_set class attribute of S8X30x
OT = Enum('OT', ['sr',           # source register
                 'dr',           # destination register
                 'siv', 'div',   # IV operand (liv*/riv*) as source / destination
                 'blen',         # length, 1-8 bits (8 encoded as zero)
                 'brot',         # rotation, 0-7 bits
                 'imm',          # immediate value, 5 or 8 bit
                 'jmp5', 'jmp8', # jump target (within page)
                 'jmp13'         # jump target (absolute address)
                ])


def bit_count(v):
    """Return how many '1' digits appear in the binary form of *v*."""
    ones = 0
    for digit in bin(v):
        if digit == '1':
            ones += 1
    return ones


class BitField:
    """Per-byte bit masks locating one operand field in an instruction.

    mask[i] has a 1 in every bit position of instruction byte i that
    belongs to the field.  width is the number of mask bits accumulated
    through append().
    """

    def __init__(self, byte_count = 0):
        """Start with *byte_count* all-zero mask bytes and zero width."""
        self.width = 0  # width of the field within the instruction
        self.mask = bytearray(byte_count)

    def __repr__(self):
        # The closing parenthesis was missing from the original format string.
        return 'BitField(width = %d, mask = %s)' % (self.width, str(self.mask))

    def append(self, mask_byte):
        """Add the mask for the next instruction byte and grow width to match."""
        self.mask.append(mask_byte)
        self.width += bit_count(mask_byte)

    def pad_length(self, length):
        """Extend mask with zero bytes until it is at least *length* long."""
        if len(self.mask) < length:
            self.mask += bytearray(length - len(self.mask))

    def insert(self, bits, value):
        """OR *value* into the field's bit positions of *bits*, in place.

        Bytes are visited from index 0 upward and, within a byte, bits
        from least to most significant; each field bit consumes the next
        least-significant bit of *value*.
        """
        assert isinstance(value, int)
        for i in range(len(bits)):
            for j in range(8):
                b = 1 << j
                if self.mask[i] & b:
                    if value & 1:
                        bits[i] |= b
                    value >>= 1
        #assert value == 0  # XXX causes negative 8-bit immediates to fail
        

# An instruction form is a variant of an instruction that takes
# specific operand types.
class Form:
    """One operand-signature variant of an instruction.

    *encoding* is a string of 8-character groups (spaces are ignored);
    '0' and '1' are fixed bits and any other character names a field.
    Parsing yields:
      bits   - list of fixed-bit values, one per instruction byte
      mask   - list of masks marking which bits of each byte are fixed
      fields - dict mapping field letter to its BitField
    *operands* is stored unchanged.
    """

    @staticmethod
    def __byte_parse(bs):
        # Character 7 of the group is bit 0, character 0 is bit 7.
        value = 0
        fixed = 0
        field_masks = { }
        for pos in range(8):
            ch = bs[7 - pos]
            bit = 1 << pos
            if ch == '1':
                value |= bit
                fixed |= bit
            elif ch == '0':
                fixed |= bit
            else:
                field_masks[ch] = field_masks.get(ch, 0) | bit
        return value, fixed, field_masks

    @staticmethod
    def __encoding_parse(encoding):
        ep_debug = False
        if ep_debug:
            print('encoding', encoding)
        remaining = encoding.replace(' ', '')
        bits, mask, fields = [], [], { }
        while remaining:
            assert len(remaining) >= 8
            byte, remaining = remaining[:8], remaining[8:]
            if ep_debug:
                print('byte', byte)
            b, m, f = Form.__byte_parse(byte)
            if ep_debug:
                print('b: ', b, 'm:', m, 'f:', f)
            bits.append(b)
            mask.append(m)
            for letter, letter_mask in f.items():
                if letter not in fields:
                    # A field first seen in a later byte gets zero masks
                    # for all the bytes before it.
                    fields[letter] = BitField(len(bits) - 1)
                fields[letter].append(letter_mask)
        if ep_debug:
            print('fields before:', fields)
        # Fields that end before the last byte are padded with zero masks.
        for bitfield in fields.values():
            bitfield.pad_length(len(bits))
        if ep_debug:
            print('fields after:', fields)
        return bits, mask, fields

    def __init__(self, encoding, operands):
        self.encoding = encoding
        self.operands = operands
        parsed = Form.__encoding_parse(encoding)
        self.bits, self.mask, self.fields = parsed

    def __len__(self):
        """Number of bytes in the encoding."""
        return len(self.bits)

    def insert_fields(self, fields):
        """Return a bytearray of the fixed bits with every field value merged in.

        *fields* must supply a value for exactly the field letters of
        this form.
        """
        encoded = bytearray(self.bits)
        assert set(self.fields) == set(fields)
        for letter in self.fields:
            self.fields[letter].insert(encoded, fields[letter])
        return encoded
        


# An instruction has a single mnemonic, but possibly multiple
# forms.
class Inst:
    """A mnemonic together with the tuple of forms given after it."""

    def __init__(self, mnem, *forms):
        self.forms = forms
        self.mnem = mnem


class S8X30x:
    """Instruction table and disassembler for the Signetics 8X300/8X305.

    Firmware is accessed as fw[pc], which must yield a sequence whose
    first two items are the high and low byte of the instruction at pc.

    Field letters used in the encodings below:
      s source, d destination, r rotation, l length, i immediate,
      j jump target.
    """

    # The source operand precedes the destination operand
    __inst_set = [
        Inst('nop',   Form('00000000 00000000', ())),
        Inst('xml',   Form('11001010 iiiiiiii', (OT.imm,))),
        Inst('xmr',   Form('11001011 iiiiiiii', (OT.imm,))),

        Inst('move',  Form('000sssss rrrddddd', (OT.sr,  OT.brot, OT.dr)),    # source -> dest
                      Form('000sssss lllddddd', (OT.sr,  OT.blen, OT.div)),
                      Form('000sssss lllddddd', (OT.siv, OT.blen, OT.dr)),
                      Form('000sssss lllddddd', (OT.siv, OT.blen, OT.div))),

        Inst('add',   Form('001sssss rrrddddd', (OT.sr,  OT.brot, OT.dr)),    # source + AUX -> dest, updates OVF
                      Form('001sssss lllddddd', (OT.sr,  OT.blen, OT.div)),
                      Form('001sssss lllddddd', (OT.siv, OT.blen, OT.dr)),
                      Form('001sssss lllddddd', (OT.siv, OT.blen, OT.div))),

        Inst('and',   Form('010sssss rrrddddd', (OT.sr,  OT.brot, OT.dr)),    # source & AUX -> dest
                      Form('010sssss lllddddd', (OT.sr,  OT.blen, OT.div)),
                      Form('010sssss lllddddd', (OT.siv, OT.blen, OT.dr)),
                      Form('010sssss lllddddd', (OT.siv, OT.blen, OT.div))),

        Inst('xor',   Form('011sssss rrrddddd', (OT.sr,  OT.brot, OT.dr)),    # source XOR AUX -> dest
                      Form('011sssss lllddddd', (OT.sr,  OT.blen, OT.div)),
                      Form('011sssss lllddddd', (OT.siv, OT.blen, OT.dr)),
                      Form('011sssss lllddddd', (OT.siv, OT.blen, OT.div))),

        Inst('xec',   Form('100sssss jjjjjjjj', (OT.jmp8, OT.sr)),               # execute instruction at S+i
                      Form('100sssss llljjjjj', (OT.jmp5, OT.siv, OT.blen))),

        Inst('nzt',   Form('101sssss jjjjjjjj', (OT.sr, OT.jmp8)),               # jump if S is non-zero
                      Form('101sssss llljjjjj', (OT.siv, OT.blen, OT.jmp5))),

        Inst('xmit',  Form('110ddddd iiiiiiii', (OT.imm, OT.dr)),                # immediate -> dest
                      Form('110ddddd llliiiii', (OT.imm, OT.blen, OT.div))),

        Inst('jmp',   Form('111jjjjj jjjjjjjj', (OT.jmp13,)))
    ]


    def __opcode_init(self):
        """Build the mnemonic -> Inst and opcode -> [Inst, ...] lookup tables.

        The opcode is the top three bits of the first encoding byte.
        Instructions sharing an opcode keep their __inst_set order, which
        is the order inst_search() tries them in.
        """
        self.__inst_by_mnemonic = { }
        self.__inst_by_opcode = { }
        for inst in self.__inst_set:
            assert inst.mnem not in self.__inst_by_mnemonic
            self.__inst_by_mnemonic[inst.mnem] = inst
            form = inst.forms[0]  # assumes all forms of an inst have same opcode
            opcode = form.bits[0] >> 5
            self.__inst_by_opcode.setdefault(opcode, []).append(inst)

    def _opcode_table_print(self):
        """Print every form of every instruction, sorted by mnemonic."""
        for mnem in sorted(self.__inst_by_mnemonic.keys()):
            inst = self.__inst_by_mnemonic[mnem]
            for form in inst.forms:
                print('%-04s:' % mnem, form.encoding, form.operands)


    @staticmethod
    def __extract_field(inst, fields, f):
        """Collect the bits of field *f* from the instruction bytes *inst*.

        Bytes are scanned from index 0 and bits from 7 down to 0, so the
        first field bit met becomes the most significant bit of the result.
        """
        mask = fields[f].mask
        v = 0
        for i in range(min(len(inst), len(mask))):
            for j in reversed(range(8)):
                if mask[i] & (1 << j):
                    v = (v << 1) | ((inst[i] >> j) & 1)
        return v


    @staticmethod
    def encoding_match(d, form):
        """Return True if the fixed bits of *form* agree with the bytes *d*."""
        return all((byte & m) == (b & m)
                   for byte, m, b in zip(d, form.mask, form.bits))


    def form_search(self, fw, pc, inst):
        """Find the first form of *inst* that fits the instruction at *pc*.

        A form fits when its fixed bits match and its register operands
        are legal for self.cpu_type (register form vs. IV form).

        Returns (form, fields) where fields maps field letters to values;
        for jmp5/jmp8 forms the 'j' value has the page bits of pc added.
        Returns (None, None) when no form fits.
        """
        word = fw[pc][:2]
        for form in inst.forms:
            if not self.encoding_match(word, form):
                continue
            fields = { }
            for f in form.fields:
                fields[f] = self.__extract_field(word, form.fields, f)
            if 's' in fields:
                sr = Reg(fields['s'])
                if OT.sr in form.operands and not sr.is_src_reg(self.cpu_type):
                    continue
                elif OT.siv in form.operands and not sr.is_iv(self.cpu_type):
                    continue
            if 'd' in fields:
                dr = Reg(fields['d'])
                if OT.dr in form.operands and not dr.is_dest_reg(self.cpu_type):
                    continue
                elif OT.div in form.operands and not dr.is_iv(self.cpu_type):
                    continue
            if 'j' in fields:
                if OT.jmp8 in form.operands:
                    # XXX should it be (pc+1) & 0xff00?
                    fields['j'] = fields['j'] + (pc & 0xff00)
                elif OT.jmp5 in form.operands:
                    # XXX should it be (pc+1) & 0xffe0?
                    fields['j'] = fields['j'] + (pc & 0xffe0)
            return form, fields
        return None, None


    def inst_search(self, fw, pc):
        """Return (inst, form, fields) for the instruction at *pc*.

        Raises BadInstruction, carrying the 16-bit instruction word, when
        no instruction with that opcode has a fitting form.
        """
        word = fw[pc]
        opcode = word[0] >> 5
        for inst in self.__inst_by_opcode.get(opcode, ()):
            form, fields = self.form_search(fw, pc, inst)
            if form is not None:
                return inst, form, fields
        raise BadInstruction((word[0] << 8) | word[1])

    @staticmethod
    def ihex(v):
        """Format *v* as lower-case hex with an 'h' suffix.

        A leading '0' is added when the text would start with a letter.
        """
        s = '%xh' % v
        if s[0].isalpha():
            s = '0' + s
        return s


    @staticmethod
    def __iv_operand(prefix, reg, ftemp):
        """Format an IV operand as prefix + 'liv'/'riv' plus a bit range.

        The range runs from the register's rightmost bit up over the
        length in ftemp['l'] (0 encodes 8); without an 'l' field the left
        bit is 7.  The range is omitted when it is exactly [7:0].
        """
        rb = reg.rightmost_liv_bit()
        if 'l' in ftemp:
            length = ftemp['l']
            if length == 0:
                length = 8
            lb = rb + length - 1
        else:
            lb = 7
        value = prefix + reg.name[:3]
        if rb != 0 or lb != 7:
            if rb == lb:
                value += '[%d]' % lb
            else:
                value += '[%d:%d]' % (lb, rb)
        return value


    def disassemble_inst(self, fw, pc, symtab_by_value = None, disassemble_operands = True):
        """Disassemble the instruction at *pc*.

        symtab_by_value optionally maps jump target addresses to label
        names used in place of hex addresses.  With disassemble_operands
        false the operand string is left empty.

        Returns (mnemonic, operand_string, fields): the mnemonic is
        left-justified in six characters, and fields is the dict produced
        by form_search().  An instruction that cannot be decoded is
        returned as ('dw      ', <hex of the 16-bit word>, {}).
        """
        if symtab_by_value is None:
            symtab_by_value = {}

        try:
            inst, form, fields = self.inst_search(fw, pc)
        except BadInstruction:
            return 'dw      ', '%s' % self.ihex((fw[pc][0] << 8) + fw[pc][1]), {}

        s = '%-6s' % inst.mnem
        operands = []

        if disassemble_operands:
            # Fields are removed from this copy as they are consumed so
            # that anything left over can be reported.
            ftemp = fields.copy()
            blen_used = False
            for operand in form.operands:
                if operand == OT.blen or operand == OT.brot:
                    # Folded into the register / IV operand they modify.
                    continue
                elif operand == OT.sr:
                    value = Reg(ftemp.pop('s')).name
                    if 'r' in ftemp:
                        brot = ftemp.pop('r')
                        if brot != 0:
                            value += '>>>%d' % brot
                elif operand == OT.siv:
                    value = self.__iv_operand('s', Reg(ftemp.pop('s')), ftemp)
                    blen_used = True
                elif operand == OT.dr:
                    value = Reg(ftemp.pop('d')).name
                elif operand == OT.div:
                    value = self.__iv_operand('d', Reg(ftemp.pop('d')), ftemp)
                    blen_used = True
                elif operand == OT.imm:
                    value = self.ihex(ftemp.pop('i'))
                elif operand in (OT.jmp5, OT.jmp8, OT.jmp13):
                    target = ftemp.pop('j')
                    if target in symtab_by_value:
                        value = symtab_by_value[target]
                    else:
                        value = self.ihex(target)
                else:
                    # str() is required: concatenating the Enum member
                    # itself would raise TypeError instead.
                    raise NotImplementedError('operand type ' + str(operand))
                operands.append(value)
            if blen_used:
                ftemp.pop('l', None)
            if ftemp:
                raise NotImplementedError('leftover fields: ' + str(ftemp))

        return s, ','.join(operands), fields


    def __init__(self, cpu_type = CpuType.s8x300):
        """Remember the target CPU and build the opcode lookup tables."""
        self.cpu_type = cpu_type
        self.__opcode_init()

#if __name__ == '__main__':
#    s8x30x = S8X30x()
#    s8x30x._opcode_table_print()

# ======== End of s8x30x.py

# Addresses 0x30 to 0x3f are the 8x320 dual port register gadget
# Addresses 0x48 to 0x5f are the 8x330 disk controller
#
# Details on the 8x330
#  0x48 to 0x57 is a 16 item "register file" (scratch)
#
#  0x48 - 0x49 are not used
#  0x4a is used (scratchpad)
#  0x4b is used (scratchpad)
#  0x4c is used (scratchpad)
#  0x4d is used (scratchpad)
#  0x4e is used (scratchpad)
#  0x4f is used (scratchpad)
#  0x50 is used (scratchpad)
#  0x51 is used (scratchpad)
#  0x52 is used (scratchpad)
#  0x53 - 0x57 is not used
#
#  0x58 (130) = - not used (invalid)
#  0x59 (131) = - not used (invalid)
#  0x5a (132) = Command and status 1
#  0x5b (133) = Command and status 2
#  0x5c (134) = Command and status 3
#  0x5d (135) = Command and status 4
#  0x5e (136) = Sector length (never used)
#  0x5f (137) = Data
#
# -----
# We also have 3 8 bit IO ports (N8x36 devices)
# We read and write address 1 (data port)
# We only write to address 2
# We only read from address 3
# -----

# Symbolic names for the I/O addresses described in the comment above.
# Keys are the address text exactly as S8X30x.ihex() renders it (lower-case
# hex followed by 'h'), so aws_augment() can look up an operand string
# directly; values are the names substituted into the listing.
aws_regs = {
    '1h' : "DATA_PORT",
    '2h' : "CMD_PORT",
    '3h' : "STATUS_PORT",

    '30h' : "IO-0",
    '31h' : "IO-1",
    '32h' : "IO-2",
    '33h' : "IO-3",
    '34h' : "IO-4",
    '35h' : "IO-5",
    '36h' : "IO-6",
    '37h' : "IO-7",
    '38h' : "IO-8",
    '39h' : "IO-9",
    '3ah' : "IO-A",
    '3bh' : "IO-B",
    '3ch' : "IO-C",
    '3dh' : "IO-D",
    '3eh' : "IO-E",
    '3fh' : "IO-F",

    '48h' : "VAL0",    # never used
    '49h' : "VAL1",    # never used
    '4ah' : "VAL2",
    '4bh' : "VAL3",
    '4ch' : "VAL4",
    '4dh' : "VAL5",
    '4eh' : "VAL6",
    '4fh' : "VAL7",
    '50h' : "VAL8",
    '51h' : "VAL9",
    '52h' : "VALA",
    '53h' : "VALB",    # never used
    '54h' : "VALC",    # never used
    '55h' : "VALD",    # never used
    '56h' : "VALE",    # never used
    '57h' : "VALF",    # never used

    '5ah' : "CSR1",
    '5bh' : "CSR2",
    '5ch' : "CSR3",
    '5dh' : "CSR4",

    '5eh' : "SECTOR_LEN",
    '5fh' : "FDC_DATA"
}

# We only ever see this for IVL
# xmit    0ffh,ivl

# add mnemonic names for the AWS "turbo" disk controller
# We want to change lines like this:
#  xmit    5ch,ivr
# to this:
#  xmit    CSR3,ivr

def aws_equ_list () :
    """Print one 'NAME EQU address' line per aws_regs entry, then an empty line."""
    for address in aws_regs :
        name = aws_regs[address]
        print ( '%-12s EQU  %s' % (name, address) )
    print ( "" )

def aws_augment ( dis, operands ) :
    """Replace the immediate of 'xmit <addr>,ivr' with its aws_regs name.

    dis is the mnemonic as returned by disassemble_inst (it may carry
    trailing padding); operands is the comma-separated operand string.

    Returns (operands, tail).  operands is rewritten only for an xmit to
    ivr whose first operand is a key of aws_regs; every other input is
    passed through unchanged.  tail is text to append to the output line
    and is currently always None.
    """
    # The mnemonic arrives blank-padded to a fixed width; strip it so the
    # comparison does not depend on the exact amount of padding.
    if dis.strip() != "xmit" :
        return operands, None
    ops = operands.split ( ',' )
    # An operand string with fewer than two fields (e.g. empty) cannot be
    # an xmit to ivr; indexing ops[1] there used to raise IndexError.
    if len ( ops ) < 2 or ops[1] != "ivr" :
        return operands, None
    if ops[0] not in aws_regs :
        return operands, None
    return aws_regs[ops[0]] + "," + ops[1], None

# Pass 1 just builds symbol table (list of jump targets).
def pass1(s8x30x, fw, base):
    """Collect jump targets and give each a label.

    Every address from base to base+len(fw)-1 is decoded without operand
    text; whenever the decoded fields contain 'j', that target address is
    mapped to a label of the form 'x' + four hex digits.

    Returns the dict mapping target address to label.
    """
    labels = {}
    last = base + len(fw)
    for pc in range(base, last):
        _mnem, _text, fields = s8x30x.disassemble_inst(fw, pc, disassemble_operands = False)
        if 'j' not in fields:
            continue
        target = fields['j']
        labels[target] = 'x%04x' % target
    return labels


def pass2(s8x30x, fw, base, symtab_by_value, show_obj = False, output_file = sys.stdout):
    """Write one disassembly line per address to output_file.

    Each line holds an optional label (when the address is a key of
    symtab_by_value), the mnemonic and the operands after aws_augment()
    has had a chance to substitute register names.  With show_obj the
    line is prefixed by the address and the raw bytes in hex.
    """
    for pc in range(base, base + len(fw)):
        dis, operands, fields = s8x30x.disassemble_inst(fw, pc, symtab_by_value)

        pieces = []
        if show_obj:
            pieces.append('%04x: ' % pc)
            pieces.extend('%02x ' % byte for byte in fw[pc])

        label = (symtab_by_value[pc] + ':') if pc in symtab_by_value else ''

        # A fast IO decoder could be applied here; it would need some
        # kind of third file to designate things, so it is not hooked up.

        # Added by tjt 2-17-2021
        # "fields" is not needed for this step
        operands, tail = aws_augment ( dis, operands )

        pieces.append('%-8s%-8s%s' % (label, dis, operands))
        if tail :
            pieces.append(tail)
        pieces.append('\n')

        output_file.write(''.join(pieces))
    

def disassemble(s8x30x, fw, show_obj = False, output_file = sys.stdout, base = 0):
    """Write a complete disassembly of fw to output_file.

    The output starts with the EQU list of symbolic register names,
    followed by the listing produced by the two passes: pass1 gathers
    jump-target labels and pass2 emits the text.

    show_obj adds address and object bytes to each line; base is the
    address of the first instruction.
    """
    import contextlib

    # aws_equ_list() print()s to sys.stdout.  Redirect it so the EQU header
    # ends up in the same stream as the listing instead of being split off
    # to the terminal when an output file was requested.
    with contextlib.redirect_stdout(output_file):
        aws_equ_list ()

    symtab_by_value = pass1(s8x30x, fw, base)
    #symtab_by_name = { v: k for k, v in symtab_by_value.items() }
    pass2(s8x30x, fw, base, symtab_by_value, show_obj = show_obj, output_file = output_file)


# type function for argparse to support numeric arguments in hexadecimal
# ("0x" prefix) as well as decimal (no prefix)
def auto_int(x):
    """Convert the string *x* to int, taking the base from its prefix."""
    base_from_prefix = 0  # base 0 makes int() honour prefixes such as "0x"
    return int(x, base_from_prefix)


# Command-line entry point: parse the options, load the input files into a
# Memory object and write the disassembly.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description = 'Disassembler for Signetics 8X300/8X305')

    parser.add_argument('-l', '--listing', action='store_true',
                        help = 'generate output in listing format')

    fmt_group = parser.add_mutually_exclusive_group()
    fmt_group.add_argument('--binary', action='store_const',
                           dest='inputformat',
                           const='binary',
                           help = 'input file format is raw binary (default)')

#    fmt_group.add_argument('--hex', action='store_const',
#                           dest='inputformat',
#                           const='hex',
#                           help = 'input file format is Intel hex')

    cpu_type_group = parser.add_mutually_exclusive_group()
    cpu_type_group.add_argument('-0', '--8x300',
                                action='store_const',
                                dest='cpu_type',
                                const=CpuType.s8x300,
                                help = '8X300 processor')
    cpu_type_group.add_argument('-5', '--8x305',
                                action='store_const',
                                dest='cpu_type',
                                const=CpuType.s8x305,
                                help = '8X305 processor')

    fastio_group = parser.add_mutually_exclusive_group()

#    fmt_group.add_argument('--wd1000',
#                           action='store_const',
#                           dest='fastio',
#                           const='wd1000',
#                           help = 'decode WD1000 fast I/O select')
#    fmt_group.add_argument('--wd1001',
#                           action='store_const',
#                           dest='fastio',
#                           const='wd1001',
#                           help = 'decode WD1001 fast I/O select')

    parser.add_argument('-o', '--output', type=argparse.FileType('w'),
                        default = sys.stdout,
                        help = 'disassembly output file')

    parser.add_argument('input',
                        type = argparse.FileType('rb'),
                        nargs = '*',
                        help = 'input file(s), multiple files will be interleaved (useful for separate even, odd files)')

    args = parser.parse_args()
    #print(args)

    if len(args.input) < 2:
        print('Minimum two object files required', file = sys.stderr)
        sys.exit(2)

    # argparse always creates the attribute (set to None when neither -0
    # nor -5 is given), so a membership test on the namespace never fires;
    # test the value instead.
    if args.cpu_type is None:
        args.cpu_type = CpuType.s8x300

    s8x30x = S8X30x(cpu_type = args.cpu_type)

#    if args.inputformat == 'hex':
#        data = [IntelHex().read(d) for d in data]

    fast_io_decoder = None

#    if args.fastio == 'wd1000':
#        fast_io_decoder = WD1000()
#    elif args.fastio == 'wd1001':
#        fast_io_decoder = WD1001()

    try:
        fw = Memory ( args.input )
    except LengthMismatch:
        print('Input files must all be the same length', file = sys.stderr)
        sys.exit(2)
    finally:
        # Memory has read everything it needs; release the file handles.
        for f in args.input:
            f.close()

    disassemble(s8x30x, fw, show_obj = args.listing, output_file = args.output)

# THE END
