shithub: pokecrystal

Download patch

ref: 3bd84c1dac0dc7085287bd6a7c822dfa2663cf71
parent: f93de7b1bda5c4812dc178108019fbf8a8836dad
author: Bryan Bishop <[email protected]>
date: Sat Mar 24 14:01:37 EDT 2012

lots of asm-related code and tests

--- a/extras/crystal.py
+++ b/extras/crystal.py
@@ -4405,6 +4405,463 @@
         #set the value in the original dictionary
         map_names[map_group_id][map_id]["label"] = cleaned_name
 
+#### asm utilities ####
+#these are pulled in from pokered/extras/analyze_incbins.py
+
+#module-level state shared by the asm helpers below:
+#  asm               - source lines of main.asm (None until load_asm runs)
+#  incbin_lines      - INCBIN "baserom.gbc" lines collected by isolate_incbins
+#  processed_incbins - parsed incbins keyed by asm line number (process_incbins)
+asm, incbin_lines, processed_incbins = None, [], {}
+
+def isolate_incbins():
+    """Collect every INCBIN line that pulls from baserom.gbc.
+
+    Leading spaces (but not tabs) are removed from each collected line.
+    The result is stored in the global incbin_lines and also returned."""
+    global incbin_lines
+    incbin_lines = []
+    for raw_line in asm:
+        #empty and all-space lines carry nothing
+        if raw_line.count(" ") == len(raw_line):
+            continue
+        stripped = raw_line.lstrip(" ")
+        if stripped.startswith("INCBIN") and "baserom.gbc" in stripped:
+            incbin_lines.append(stripped)
+    return incbin_lines
+
+def process_incbins():
+    """parse incbin lines into memory
+
+    Rebuilds the global processed_incbins from incbin_lines. Entries are
+    keyed by the incbin's line number in asm and look like
+        {"line_number", "line", "start", "interval", "end"}
+    where end = start + interval (one past the last included byte).
+    Incbins with an interval of 0 are left out.
+
+    raises ValueError if an incbin line cannot be found in asm."""
+    global processed_incbins
+    processed_incbins = {} #reset, so entries from an earlier run don't linger
+    for incbin in incbin_lines:
+        #isolate_incbins stores lines with leading spaces removed, so compare
+        #asm lines the same way when looking up where this incbin lives
+        stripped = incbin.lstrip(" ")
+        line_number = next((number for (number, asm_line) in enumerate(asm)
+                            if asm_line.lstrip(" ") == stripped), None)
+        if line_number is None:
+            raise ValueError("incbin line not found in asm: " + incbin)
+
+        #after 'INCBIN "baserom.gbc",' (21 characters) comes "start,interval"
+        fields = stripped[21:].split(",")
+
+        #NOTE: eval runs on text taken from main.asm; only use this on a
+        #trusted source file
+        start = eval(fields[0].replace("$", "0x"))
+
+        #the interval may be an expression such as "$4000 - $3f00" and may be
+        #followed by a ";" comment, which eval then sees as a "#" comment
+        interval_expr = fields[1].replace(";", "#")
+        interval_expr = interval_expr.replace("$", "0x").replace("0xx", "0x")
+        interval = eval(interval_expr)
+
+        #don't add this incbin if the interval is 0
+        if interval != 0:
+            processed_incbins[line_number] = {
+                "line_number": line_number,
+                "line": incbin,
+                "start": start,
+                "interval": interval,
+                "end": start + interval,
+            }
+
+def reset_incbins():
+    """reset asm before inserting another diff
+
+    Clears the module-level asm, incbin_lines and processed_incbins, then
+    calls load_asm() and re-parses the incbins. The global statement is
+    needed: without it these assignments would only create locals and
+    nothing would actually be reset."""
+    global asm, incbin_lines, processed_incbins
+    asm = None
+    incbin_lines = []
+    processed_incbins = {}
+    load_asm()
+    isolate_incbins()
+    process_incbins()
+
+def find_incbin_to_replace_for(address, debug=False, rom_file="../baserom.gbc"):
+    """returns a line number for which incbin to edit
+    if you were to insert bytes into main.asm
+
+    address: int, or a hex string such as "0x4000" or "4000"
+    rom_file: rom used to bounds-check address
+
+    Returns the processed_incbins key (asm line number) of the incbin that
+    covers address, or None if no processed incbin covers it. Raises
+    IndexError when address lies outside rom_file.
+
+    Each incbin covers [start, end) because end = start + interval. An
+    address equal to one incbin's end therefore belongs to the next incbin.
+    The half-open comparison ensures at most one incbin can match."""
+    if isinstance(address, str): address = int(address, 16)
+    if not (0 <= address <= os.lstat(rom_file).st_size):
+        raise IndexError("address is out of bounds")
+    for (incbin_key, incbin) in processed_incbins.items():
+        start = incbin["start"]
+        end = incbin["end"]
+        if debug:
+            print "start is: " + str(start)
+            print "end is: " + str(end)
+            print "address is: " + str(type(address))
+            print "checking.... " + hex(start) + " <= " + hex(address) + " < " + hex(end)
+        if start <= address < end:
+            return incbin_key
+    return None
+
+def split_incbin_line_into_three(line, start_address, byte_count, rom_file="../baserom.gbc"):
+    """
+    splits an incbin line into three pieces.
+    you can replace the middle one with the new content of length byte_count
+
+    line: key into processed_incbins (the incbin's line number in asm)
+    start_address: where you want to start inserting bytes (int or hex string)
+    byte_count: how many bytes you will be inserting
+    rom_file: rom used to bounds-check start_address (previously this name
+              was read from an undefined variable)
+
+    Returns up to three INCBIN lines joined by newlines, with no trailing
+    newline. The first line is omitted when start_address equals the
+    incbin's own start.
+
+    Raises IndexError if start_address lies outside rom_file or the bytes do
+    not fit inside the incbin. Raises Exception if processed_incbins is empty.
+    """
+    if isinstance(start_address, str): start_address = int(start_address, 16)
+    if not (0 <= start_address <= os.lstat(rom_file).st_size):
+        raise IndexError("start_address is out of bounds")
+    if len(processed_incbins) == 0:
+        raise Exception("processed_incbins must be populated")
+
+    original_incbin = processed_incbins[line]
+    start = original_incbin["start"]
+    end = original_incbin["end"]
+
+    #the replaced region must sit inside this incbin, otherwise the
+    #surrounding pieces would end up with negative lengths
+    if not (start <= start_address and start_address + byte_count <= end):
+        raise IndexError("bytes to insert do not fit inside this incbin")
+
+    pieces = []
+    #leading piece: start up to start_address, skipped when it would be empty
+    if start_address > start:
+        pieces.append("INCBIN \"baserom.gbc\",$" + hex(start)[2:] + ",$" + hex(start_address)[2:] + " - $" + hex(start)[2:])
+    #middle piece: the one you will replace with whatever content
+    pieces.append("INCBIN \"baserom.gbc\",$" + hex(start_address)[2:] + "," + str(byte_count))
+    #trailing piece: everything after the inserted bytes up to the old end
+    after = start_address + byte_count
+    pieces.append("INCBIN \"baserom.gbc\",$" + hex(after)[2:] + ",$" + hex(end - after)[2:])
+    return "\n".join(pieces)
+
+def generate_diff_insert(line_number, newline):
+    """return a unified diff that replaces asm[line_number] with newline
+
+    newline may itself contain several lines. The diff is taken against
+    ../main.asm, so apply_diff can hand it directly to `patch ../main.asm`.
+    Requires subprocess.check_output (python 2.7+); on older versions the
+    AttributeError propagates to the caller."""
+    newfile = list(asm) #copy so the loaded asm itself is left untouched
+    newfile[line_number] = newline #possibly inserting multiple lines
+    newfile = "\n".join(newfile)
+
+    newfile_filename = "fjiqefo.temp"
+    with open(newfile_filename, "w") as newfile_fh:
+        newfile_fh.write(newfile)
+
+    try:
+        diffcontent = subprocess.check_output(["diff", "-u", "../main.asm", newfile_filename])
+    except subprocess.CalledProcessError as exc:
+        #diff exits with status 1 when the files differ; the diff text is
+        #still available on the exception
+        diffcontent = exc.output
+    finally:
+        #always clean up the temporary file, even if diff blew up
+        os.remove(newfile_filename)
+
+    return diffcontent
+
+def apply_diff(diff, try_fixing=True, do_compile=True):
+    """patch ../main.asm with a unified diff and optionally rebuild
+
+    A backup of the original is kept at ../main1.asm. When try_fixing is
+    set, that backup is restored if the patch does not apply cleanly or the
+    build fails.
+
+    Returns True if the build succeeded and False if patching or the build
+    failed. Returns None when do_compile is False and the patch applied."""
+    print "... Applying diff."
+
+    #write the diff to a file
+    with open("temp.patch", "w") as fh:
+        fh.write(diff)
+
+    #apply the patch, keeping a backup to roll back to
+    os.system("cp ../main.asm ../main1.asm")
+    patch_status = os.system("patch ../main.asm temp.patch")
+
+    #remove the patch
+    os.remove("temp.patch")
+
+    #a patch that failed to apply must not be reported as a good build
+    if patch_status != 0:
+        if try_fixing:
+            os.system("mv ../main1.asm ../main.asm")
+        return False
+
+    #confirm it's working
+    if do_compile:
+        try:
+            subprocess.check_call("cd ../; make clean; LC_CTYPE=C make", shell=True)
+            return True
+        except (subprocess.CalledProcessError, OSError):
+            if try_fixing:
+                os.system("mv ../main1.asm ../main.asm")
+            return False
+
+def index(seq, f):
+    """Return the position of the first element of seq for which
+    f(element) is true, or None when no element satisfies f."""
+    for position, item in enumerate(seq):
+        if f(item):
+            return position
+    return None
+
+def is_probably_pointer(input):
+    """return True if input is a string that parses as hexadecimal
+    (e.g. "3F0A" or "0x3F0A"), otherwise False.
+
+    Only ValueError (not hex) and TypeError (not a string) count as "no".
+    The previous bare except also swallowed KeyboardInterrupt/SystemExit."""
+    try:
+        int(input, 16)
+    except (ValueError, TypeError):
+        return False
+    return True
+
+def analyze_intervals():
+    """Return the processed baserom.gbc incbins ordered from the largest
+    interval to the smallest.
+
+    main.asm is loaded and its incbins parsed first if that hasn't
+    happened yet."""
+    global asm, processed_incbins
+    if asm is None:
+        load_asm()
+    if processed_incbins == {}:
+        isolate_incbins()
+        process_incbins()
+    #stable ascending sort followed by reverse (not reverse=True) keeps the
+    #same tie ordering as before
+    by_size = sorted(processed_incbins.values(), key=lambda incbin: incbin["interval"])
+    by_size.reverse()
+    return by_size
+
+def write_all_labels(all_labels):
+    """Serialize all_labels as JSON into labels.json in the current directory."""
+    with open("labels.json", "w") as fh:
+        fh.write(json.dumps(all_labels))
+
+def remove_quoted_text(line):
+    """Strip every quoted span, including the quotes, out of line.
+
+    Double-quoted spans are removed first, then single-quoted ones. A quote
+    character is only processed while it appears an even, non-zero number
+    of times, so an unbalanced quote leaves that kind of quote untouched."""
+    for quote in ("\"", "'"):
+        while line.count(quote) and not line.count(quote) % 2:
+            opening = line.find(quote)
+            closing = line.find(quote, opening + 1)
+            line = line[:opening] + line[closing + 1:]
+    return line
+
+def _hex_or_none(text):
+    """return int(text, 16), or None when text is not valid hexadecimal"""
+    try:
+        return int(text, 16)
+    except ValueError:
+        return None
+
+def line_has_comment_address(line, returnable=None):
+    """checks that a given line has a comment
+    with a valid address
+
+    Only the first space-separated token of the comment is examined.
+    Accepted forms are 1F, $1F and 0x1F, where the bank comes from
+    calculate_bank(offset). Bank:offset pairs such as 3:3F0A, $3:$3F0A and
+    0x3:0x3F0A are also accepted.
+
+    returnable: optional dict whose "bank", "offset" and "address" keys are
+    reset to None and then filled in when an address is found. When no
+    dict is passed, a fresh one is created per call, so no state is shared
+    between calls.
+
+    Returns True when an address was parsed, else False. Malformed hex such
+    as "; 0xZZ", "; $GG" or "; 3:" gives False instead of raising ValueError.
+    """
+    if returnable is None:
+        returnable = {}
+    #first set the bank/offset to nada
+    returnable["bank"] = None
+    returnable["offset"] = None
+    returnable["address"] = None
+    #only valid characters are 0-9a-f (tokens are lowercased before checking)
+    valid = [str(x) for x in range(0,10)] + [chr(x) for x in range(97, 102+1)]
+    #check if there is a comment in this line
+    if ";" not in line:
+        return False
+    #first throw away anything in quotes, since a ";" inside a string
+    #doesn't start a comment
+    if (line.count("\"") % 2 == 0 and line.count("\"")!=0) \
+       or (line.count("\'") % 2 == 0 and line.count("\'")!=0):
+        line = remove_quoted_text(line)
+    #check if there is still a comment in this line after quotes removed
+    if ";" not in line:
+        return False
+    #but even if there's a semicolon there must be later text
+    if line[-1] == ";":
+        return False
+    #and just a space doesn't count
+    if line[-2:] == "; ":
+        return False
+    #and multiple whitespace doesn't count either
+    line = line.rstrip(" ")
+    if line[-1] == ";":
+        return False
+    #there must be more content after the semicolon
+    if len(line)-1 == line.find(";"):
+        return False
+    #split it up into the main comment part
+    comment = line[line.find(";")+1:]
+    #don't want no leading whitespace
+    comment = comment.lstrip(" ").rstrip(" ")
+    #split up multi-token comments into single tokens
+    token = comment
+    if " " in comment:
+        #use the first token in the comment
+        token = comment.split(" ")[0]
+    if token in ["0x", "$", "x", ":"]:
+        return False
+    bank, offset = None, None
+    #process a token with a A:B format
+    if ":" in token: #3:3F0A, $3:$3F0A, 0x3:0x3F0A, 3:3F0A
+        #split up the token
+        bank_piece = token.split(":")[0].lower()
+        offset_piece = token.split(":")[1].lower()
+        #filter out blanks/duds
+        if bank_piece in ["$", "0x", "x"] \
+        or offset_piece in ["$", "0x", "x"]:
+            return False
+        #they can't have both "$" and "x"
+        if "$" in bank_piece and "x" in bank_piece:
+            return False
+        if "$" in offset_piece and "x" in offset_piece:
+            return False
+        #process the bank piece
+        if "$" in bank_piece:
+            bank_piece = bank_piece.replace("$", "0x")
+        #check characters for validity?
+        for c in bank_piece.replace("x", ""):
+            if c not in valid:
+                return False
+        bank = _hex_or_none(bank_piece)
+        #process the offset piece
+        if "$" in offset_piece:
+            offset_piece = offset_piece.replace("$", "0x")
+        #check characters for validity?
+        for c in offset_piece.replace("x", ""):
+            if c not in valid:
+                return False
+        offset = _hex_or_none(offset_piece)
+        #empty or misplaced pieces such as "3:" or "x3:1" aren't addresses
+        if bank is None or offset is None:
+            return False
+    #filter out blanks/duds
+    elif token in ["$", "0x", "x"]:
+        return False
+    #can't have both "$" and "x" in the number
+    elif "$" in token and "x" in token:
+        return False
+    elif "x" in token and not "0x" in token: #it should be 0x
+        return False
+    elif "$" in token and not "x" in token:
+        offset = _hex_or_none(token.replace("$", "0x"))
+        if offset is None:
+            return False
+        bank = calculate_bank(offset)
+    elif "0x" in token and not "$" in token:
+        offset = _hex_or_none(token)
+        if offset is None:
+            return False
+        bank = calculate_bank(offset)
+    else: #might just be "1" at this point
+        token = token.lower()
+        #check if there are bad characters
+        for c in token:
+            if c not in valid:
+                return False
+        offset = _hex_or_none(token)
+        if offset is None:
+            return False
+        bank = calculate_bank(offset)
+    if offset == None and bank == None:
+        return False
+    returnable["bank"] = bank
+    returnable["offset"] = offset
+    returnable["address"] = calculate_pointer(offset, bank=bank)
+    return True
+def line_has_label(line):
+    """Decide whether an asm source line defines a label.
+
+    Quoted text and anything after a ";" are ignored. The remaining code
+    must be at least two characters long and contain a ":". It must also
+    not start with a double quote and not contain "::".
+
+    Raises Exception when line is not a str."""
+    if not isinstance(line, str):
+        raise Exception("can't check this type of object")
+    code = remove_quoted_text(line.strip(" ")).split(";")[0]
+    return len(code) > 1 \
+        and ":" in code \
+        and not code.startswith("\"") \
+        and "::" not in code
+def get_label_from_line(line):
+    """Return the text before the first ":" of a label line, or None when
+    line_has_label reports that the line has no label."""
+    if line_has_label(line):
+        return line.split(":", 1)[0]
+    return None
+def find_labels_without_addresses():
+    """List every label line in asm whose comment carries no address.
+
+    Each entry is {"line_number", "line", "label"}, in source order."""
+    return [
+        {"line_number": number, "line": text, "label": get_label_from_line(text)}
+        for (number, text) in enumerate(asm)
+        if line_has_label(text) and not line_has_comment_address(text)
+    ]
+
+label_errors = "" #not read or written in this section; presumably used elsewhere
+def get_labels_between(start_line_id, end_line_id, bank_id):
+    """Collect labels whose comment gives an address, from asm lines
+    start_line_id through end_line_id inclusive.
+
+    Each result looks like:
+        {"line_number": 15, "bank": 32, "label": "PalletTownText1",
+         "offset": 0x5315, "address": 0x75315}
+    bank_id is accepted but not used. The bank comes from each line's
+    comment instead."""
+    found_labels = []
+    window = asm[start_line_id : end_line_id + 1]
+    for (relative_id, text) in enumerate(window):
+        if not line_has_label(text):
+            continue
+        line_id = start_line_id + relative_id
+        label_name = get_label_from_line(text)
+        #line_has_comment_address fills this dict with bank/offset/address
+        parsed = {}
+        if not line_has_comment_address(text, returnable=parsed):
+            continue
+        found_labels.append({
+            "line_number": line_id,
+            "bank": parsed["bank"],
+            "label": label_name,
+            "offset": parsed["offset"],
+            "address": parsed["address"],
+        })
+    return found_labels
+
+def scan_for_predefined_labels(last_bank=0x7F):
+    """looks through the asm file for labels at specific addresses,
+    this relies on the label having its address after. ex:
+
+    ViridianCity_h: ; 0x18357 to 0x18384 (45 bytes) (bank=6) (id=1)
+    PalletTownText1: ; 4F96 0x18f96
+    ViridianCityText1: ; 0x19102
+
+    Each bank starts at the line containing its section name ("bank0",
+    "bank1", ..., "bank7F"). It runs through the line where the next bank's
+    section begins, or to the end of the file when there is no next
+    section. Banks whose own section is missing are skipped. Previously the
+    pokered-specific last bank 0x2c was hard-coded here, which made bank 2C
+    swallow every later bank and crashed on the final bank.
+
+    last_bank: highest bank id to scan (0x7F for pokecrystal)
+
+    The labels are written to labels.json and also returned.
+
+    It would be more productive to use rgbasm to spit out all label
+    addresses, but faster to write this script. rgbasm would be able
+    to grab all label addresses better than this script..
+    """
+    all_labels = []
+
+    for bank_id in range(last_bank + 1):
+        #"%X" renders 0 as "0", so bank 0 needs no special case
+        abbreviation = "%X" % bank_id
+        abbreviation_next = "%X" % (bank_id + 1)
+
+        start_line_id = index(asm, lambda line: "\"bank" + abbreviation + "\"" in line)
+        end_line_id = index(asm, lambda line: "\"bank" + abbreviation_next + "\"" in line)
+        if end_line_id is None:
+            #no following section (e.g. the last bank): run to the end of the file
+            end_line_id = len(asm) - 1
+
+        print "bank" + abbreviation + " starts at " + str(start_line_id) + " to " + str(end_line_id)
+
+        if start_line_id is None:
+            continue #this bank has no section in main.asm
+
+        all_labels.extend(get_labels_between(start_line_id, end_line_id, bank_id))
+
+    write_all_labels(all_labels)
+    return all_labels
+
 #### generic testing ####
 
 class TestCram(unittest.TestCase):
@@ -4615,6 +5072,69 @@
         self.assertEquals(len(base), asm.length())
         self.assertEquals(len(base), len(list(asm)))
         self.assertEquals(len(asm), asm.length())
+    def test_remove_quoted_text(self):
+        strip_quotes = remove_quoted_text
+        #nothing quoted: the line comes back unchanged
+        self.assertEqual(strip_quotes("hello world"), "hello world")
+        #a double-quoted span is dropped together with its quotes
+        self.assertEqual(strip_quotes("hello \"world\""), "hello ")
+        double_quoted = 'hello world "testing 123"'
+        self.assertNotEqual(strip_quotes(double_quoted), double_quoted)
+        #single-quoted spans are removed as well
+        single_quoted = "hello world 'testing 123'"
+        stripped = strip_quotes(single_quoted)
+        self.assertNotEqual(stripped, single_quoted)
+        self.assertFalse("testing" in stripped)
+    def test_line_has_comment_address(self):
+        has_address = line_has_comment_address
+        #lines with no comment, an empty comment, or a non-address comment
+        rejected = [
+            "", ";", ";;;", ":;", ":;:", ";:",
+            " ", " " * 5, " " * 10,
+            "hello world", "hello_world", "hello_world:",
+            "hello_world:;", "hello_world: ;", "hello_world: ; ",
+            "hello_world: ;" + " " * 5, "hello_world: ;" + " " * 10,
+            #this ";" is inside a string, so it does not start a comment
+            'hello world "how are you today;0x1"',
+        ]
+        for text in rejected:
+            self.assertFalse(has_address(text), repr(text))
+        #bare, $-prefixed, 0x-prefixed and bank:offset addresses
+        accepted = [
+            ";1", ";F", ";$00FF", ";0x00FF", "; 0x00FF",
+            ";$3:$300", ";0x3:$300", ";$3:0x300", ";3:300", ";3:FFAA",
+            #the quoted text is ignored, leaving the ";1" comment
+            'hello world "how are you today:0x1";1',
+        ]
+        for text in accepted:
+            self.assertTrue(has_address(text), repr(text))
+    def test_line_has_label(self):
+        #plain labels, with trailing space, and with a trailing comment
+        for labelled in ["hi:", "Hello: ", "MyLabel: ; test xyz"]:
+            self.assertTrue(line_has_label(labelled), labelled)
+        #a lone colon, a commented-out label, and "::" runs are not labels
+        for unlabelled in [":", ";HelloWorld:", "::::", ":;:;:;:::"]:
+            self.assertFalse(line_has_label(unlabelled), unlabelled)
+    def test_get_label_from_line(self):
+        #(line, expected label) pairs; None means "no label on this line"
+        expectations = [
+            ("HelloWorld: ", "HelloWorld"),
+            ("HiWorld:", "HiWorld"),
+            ("HiWorld", None),
+        ]
+        for (line, expected) in expectations:
+            self.assertEqual(get_label_from_line(line), expected)
+    def test_find_labels_without_addresses(self):
+        global asm
+        #swap in a tiny fake source file. The real one is restored afterwards,
+        #even when an assertion fails, so later tests still see the loaded asm
+        #(previously asm was always clobbered to None)
+        saved_asm = asm
+        try:
+            asm = ["hello_world: ; 0x1", "hello_world2: ;"]
+            labels = find_labels_without_addresses()
+            self.assertEqual(len(labels), 1)
+            self.assertEqual(labels[0]["label"], "hello_world2")
+            asm = ["hello world: ;1", "hello_world: ;2"]
+            labels = find_labels_without_addresses()
+            self.assertEqual(len(labels), 0)
+        finally:
+            asm = saved_asm
 class TestMapParsing(unittest.TestCase):
     #def test_parse_warp_bytes(self):
     #    pass #or raise NotImplementedError, bryan_message