shithub: pokecrystal

ref: e59fa73c95d11aaee8a5759baea3ed0b3fde2741
dir: /extras/gfx.py/

View raw version
# -*- coding: utf-8 -*-

import argparse
import errno
import os
from math import sqrt, floor, ceil

import png

from crystal import load_rom

from pokemon_constants import pokemon_constants
from trainers import trainer_group_names


# When imported as a library, load the ROM into the module-level `rom` that the
# decompress_*/grab_* helpers below read from (load_rom comes from crystal).
# NOTE(review): when run as a script nothing binds `rom` here; presumably the
# script entry point loads it -- confirm.
if __name__ != "__main__":
	rom = load_rom()


def mkdir_p(path):
	"""Create directory `path` and any missing parents, like `mkdir -p`.

	An already-existing directory is not an error. Any other OSError --
	including `path` already existing as a regular file -- is re-raised.
	"""
	try:
		os.makedirs(path)
	except OSError as exc:
		# EEXIST alone would also hide "path exists but is a file"
		if exc.errno == errno.EEXIST and os.path.isdir(path):
			pass
		else:
			raise


def hex_dump(input, debug = True):
	"""display hex dump in rows of 16 bytes

	`input` is a sequence of byte values (ints). The result starts with a
	header row of column offsets (00..0f). Each following row is
	"<address>: <bytes>": the address is zero-padded to the hex width of
	len(input), and the bytes are space-separated two-digit hex. When `debug`
	is true, the whole dump concatenated without spaces comes first on its
	own line (convenient for testing quick edits in bgb).
	"""
	bytes_per_line = 16
	hex_bytes = ['%02x' % byte for byte in input]
	
	# address column width, plus 2 for the ': ' separator
	margin = 2 + len('%x' % len(hex_bytes))
	
	lines = []
	if debug:
		lines.append(''.join(hex_bytes))
	
	# column header, aligned past the address column
	lines.append(' ' * margin + ' '.join('%02x' % column for column in range(bytes_per_line)))
	
	for start in range(0, len(hex_bytes), bytes_per_line):
		row = hex_bytes[start:start + bytes_per_line]
		# join each row on its own so a short final row has no trailing space
		lines.append(('%x' % start).zfill(margin - 2) + ': ' + ' '.join(row))
	
	return '\n'.join(lines)
	

def get_tiles(image):
	"""Group a flat 2bpp byte sequence into 16-byte (8x8 pixel) tiles.

	Returns a list of lists. Trailing bytes that do not fill a whole tile
	are dropped.
	"""
	tile_size = 16
	cursor = iter(image)
	# zipping one iterator with itself pulls tile_size consecutive items per group
	return [list(chunk) for chunk in zip(*[cursor] * tile_size)]


def connect(tiles):
	"""Flatten a list of tiles back into one contiguous list of bytes (inverse of get_tiles)."""
	return [byte for tile in tiles for byte in tile]
	

def transpose(tiles):
	"""Transpose a square tile arrangement along the line y=x.

	      horizontal    <->     vertical
	   00 01 02 03 04 05     00 06 0c 12 18 1e
	   06 07 08 09 0a 0b     01 07 0d 13 19 1f
	   0c 0d 0e 0f 10 11 <-> 02 08 0e 14 1a 20
	   12 13 14 15 16 17 <-> 03 09 0f 15 1b 21
	   18 19 1a 1b 1c 1d     04 0a 10 16 1c 22
	   1e 1f 20 21 22 23     05 0b 11 17 1d 23

	The side length is int(sqrt(len(tiles))), so the image is assumed to be
	square. The result always has len(tiles) entries.
	"""
	side = int(sqrt(len(tiles)))
	area = side * side
	result = []
	index = 0
	for _ in tiles:
		result.append(tiles[index])
		# step down one row; once past the bottom, wrap to the top of the next column
		index += side
		if index >= area:
			index += 1 - area
	return result


def to_file(filename, data):
	"""Write `data` to `filename` as raw bytes.

	`data` is an iterable of byte values (ints 0-255). Single-character strings
	are also accepted and written as their ordinal. The file is closed even if
	writing fails.
	"""
	with open(filename, 'wb') as out:
		out.write(bytes(bytearray(ord(b) if isinstance(b, str) else b for b in data)))




# basic rundown of crystal's compression scheme:

# a control command consists of
# the command (bits 5-7)
# and the count minus one (bits 0-4)
# followed by additional params

lz_lit = 0
# copy the next [count] bytes of the stream verbatim

lz_iter = 1
# write one byte (1 param) [count] times

lz_alt = 2
# write two bytes (2 params) alternately for [count] bytes, starting with the first

lz_zeros = 3
# write 00 for [count] bytes

# repeater control commands have a signed parameter used to determine the start point
# (as read by Decompressed.decompress):
# - a single byte with bit 7 set is negative: start = current output length - (byte & 0x7f) - 1
# - otherwise two bytes (big-endian) give an absolute offset from the start of the decompressed data

lz_repeat = 4
# copy [count] bytes from the decompressed data, starting at the offset

lz_flip = 5
# like lz_repeat, but with the bit order of each byte reversed (76543210 -> 01234567)

lz_reverse = 6
# copy [count] bytes from the decompressed data, walking backwards from the offset

lz_hi = 7
# -used when the count exceeds 5 bits. uses a 10-bit count instead
# -bits 2-4 now contain the control code, bits 0-1 are bits 8-9 of (count - 1)
# -the following byte contains bits 0-7 of (count - 1)

lz_end = 0xff
# if 0xff is encountered as a command byte, decompression ends

# since frontpics have animation tiles lumped onto them,
# sizes must be grabbed from base stats to know when to stop reading them

max_length = 1 << 10 # largest count a 10-bit (count - 1) field can hold
lowmax = 1 << 5 # largest count a standard 5-bit (count - 1) field can hold


class Compressed:
	"""compress 2bpp data
	
	parameters:
		[2bpp data] sequence of single-character strings, e.g. a str read from a file
		[tile arrangement] default: 'horiz' ('vert' transposes the pic's tiles first)
		[size of pic] default: None. if given, the first size*size tiles are the pic
		              and the remaining bytes are animation tiles (never transposed)
	
	the compressed data (a list of ints ending in lz_end) is left in self.output"""
	
	def __init__(self, image = None, mode = 'horiz', size = None):
	
		assert image, 'need something to compress!'
		self.image = image
		
		# only transpose pic (animtiles were never transposed in decompression)
		if size is not None:
			pic_length = (size*size)*16
			self.pic = list(image[:pic_length])
			self.animtiles = list(image[pic_length:])
		else:
			# copy into a list: a str can't be concatenated with the animtiles list below
			self.pic = list(image)
			self.animtiles = []
		
		if mode == 'vert':
			self.tiles = get_tiles(self.pic)
			self.tiles = transpose(self.tiles)
			self.pic = connect(self.tiles)
		
		self.image = self.pic + self.animtiles
		
		self.end = len(self.image)
		
		self.byte = None
		self.address = 0
		
		self.stream = []
		
		self.zeros = []
		self.alts = []
		self.num_alts = 0
		self.iters = []
		self.repeats = []
		self.flips = []
		self.reverses = []
		self.literals = []
		
		self.output = []
		
		self.compress()
	
	
	def compress(self):
		"""incomplete, but outputs working compressed data
		
		greedily encodes runs of zeros, repeated bytes and alternating bytes;
		everything else is emitted as literals. repeat/flip/reverse commands are
		not generated yet (scanRepeats/doRepeats are unfinished)."""
		
		self.address = 0
		
		# todo: self.scanRepeats(), then doRepeats/doFlips/doReverses in the loop
		
		while self.address < self.end:
			
			if self.checkWhitespace():
				self.doLiterals()
				self.doWhitespace()
			
			elif self.checkIter():
				self.doLiterals()
				self.doIter()
			
			elif self.checkAlts():
				self.doLiterals()
				self.doAlts()
			
			else: # doesn't fit any pattern -> literal
				self.addLiteral()
				self.next()
			
			self.doStream()
		
		# add any literals we've been sitting on
		self.doLiterals()
		
		# done
		self.output.append(lz_end)
	
	
	def getCurByte(self):
		"""set self.byte to the byte at self.address, or None past the end"""
		if self.address < self.end:
			self.byte = ord(self.image[self.address])
		else: self.byte = None
	
	def next(self):
		"""advance one byte and refresh self.byte"""
		self.address += 1
		self.getCurByte()
	
	def addLiteral(self):
		"""queue the current byte as a literal, flushing once max_length are queued"""
		self.getCurByte()
		self.literals.append(self.byte)
		if len(self.literals) > max_length:
			raise Exception("literals exceeded max length and the compressor didn't catch it")
		elif len(self.literals) == max_length:
			self.doLiterals()
	
	def doLiterals(self):
		"""write a literal command for the queued bytes (if any) straight to self.output"""
		if len(self.literals) > lowmax:
			self.output.append( (lz_hi << 5) | (lz_lit << 2) | ((len(self.literals) - 1) >> 8) )
			self.output.append( (len(self.literals) - 1) & 0xff )
		elif len(self.literals) > 0:
			self.output.append( (lz_lit << 5) | (len(self.literals) - 1) )
		for byte in self.literals:
			self.output.append(byte)
		self.literals = []
	
	def doStream(self):
		"""move the queued command bytes in self.stream to self.output"""
		for byte in self.stream:
			self.output.append(byte)
		self.stream = []
	
	
	def scanRepeats(self):
		"""works, but doesn't do flipped/reversed streams yet
		
		this takes up most of the compress time and only saves a few bytes
		it might be more feasible to exclude it entirely"""
		
		self.repeats = []
		self.flips = []
		self.reverses = []
		
		# make a 5-letter word list of the sequence
		letters = 5 # how many bytes it costs to use a repeat over a literal
		# any shorter and it's not worth the trouble
		num_words = len(self.image) - letters
		words = []
		for i in range(self.address,num_words):
			word = []
			for j in range(letters):
				word.append( ord(self.image[i+j]) )
			words.append((word, i))
		
		zeros = []
		for zero in range(letters):
			zeros.append( 0 )
		
		# check for matches
		def get_matches():
		# TODO:
		# append to 3 different match lists instead of yielding to one
		#
		#flipped = []
		#for byte in enumerate(this[0]):
		#	flipped.append( sum(1<<(7-i) for i in range(8) if (this[0][byte])>>i&1) )
		#reversed = this[0][::-1]
		#
			for whereabout, this in enumerate(words):
				for that in range(whereabout+1,len(words)):
					if words[that][0] == this[0]:
						if words[that][1] - this[1] >= letters:
							# remove zeros
							if this[0] != zeros:
								yield [this[0], this[1], words[that][1]]
		
		matches = list(get_matches())
		
		# remove more zeros
		buffer = []
		for match in matches:
			# count consecutive zeros in a word
			num_zeros = 0
			highest = 0
			for j in range(letters):
				if match[0][j] == 0:
					num_zeros += 1
				else:
					if highest < num_zeros: highest = num_zeros
					num_zeros = 0
			if highest < 4:
				# any more than 3 zeros in a row isn't worth it
				# (and likely to already be accounted for)
				buffer.append(match)
		matches = buffer
		
		# combine overlapping matches
		buffer = []
		for this, match in enumerate(matches):
			if this < len(matches) - 1: # special case for the last match
				if matches[this+1][1] <= (match[1] + len(match[0])): # check overlap
					if match[1] + len(match[0]) < match[2]:
						# next match now contains this match's bytes too
						# this only appends the last byte (assumes overlaps are +1
						match[0].append(matches[this+1][0][-1])
						matches[this+1] = match
					elif match[1] + len(match[0]) == match[2]:
						# we've run into the thing we matched
						buffer.append(match)
					# else we've gone past it and we can ignore it
				else: # no more overlaps
					buffer.append(match)
			else: # last match, so there's nothing to check
				buffer.append(match)
		matches = buffer
		
		# remove alternating sequences
		buffer = []
		for match in matches:
			for i in range(6 if letters > 6 else letters):
				if match[0][i] != match[0][i&1]:
					buffer.append(match)
					break
		matches = buffer
		
		self.repeats = matches
		
	
	def doRepeats(self):
		"""doesn't output the right values yet
		
		NOTE(review): the 1-byte (negative) offset written here is (-distance) & 0xff,
		but Decompressed reads such a byte as output length - (byte & 0x7f) - 1,
		which would need 0x80 | (distance - 1) instead. counts over lowmax would also
		need the lz_hi form. not called by compress() yet."""
		
		unusedrepeats = []
		for repeat in self.repeats:
			if self.address >= repeat[2]:
				
				# how far in we are
				length = (len(repeat[0]) - (self.address - repeat[2]))
				
				# decide which side we're copying from
				if (self.address - repeat[1]) <= 0x80:
					self.doLiterals()
					self.stream.append( (lz_repeat << 5) | length - 1 )
					
					# wrong?
					self.stream.append( (((self.address - repeat[1])^0xff)+1)&0xff )

				else:
					self.doLiterals()
					self.stream.append( (lz_repeat << 5) | length - 1 )
					
					# wrong?
					self.stream.append(repeat[1]>>8)
					self.stream.append(repeat[1]&0xff)
				
				self.address += length
				
			else: unusedrepeats.append(repeat)
				
		self.repeats = unusedrepeats
	
	
	def checkWhitespace(self):
		"""look for a run of 00 bytes (at most max_length) at the current address
		
		on a run of 2 or more, leaves self.address just past it and returns True;
		otherwise restores self.address and returns False"""
		self.zeros = []
		self.getCurByte()
		original_address = self.address
		
		if self.byte == 0:
			# stop at max_length: a longer count doesn't fit the 10-bit field
			while self.byte == 0 and len(self.zeros) < max_length:
				self.zeros.append(self.byte)
				self.next()
			if len(self.zeros) > 1:
				return True
		self.address = original_address
		return False
	
	def doWhitespace(self):
		"""queue a zeros command for the run found by checkWhitespace()"""
		if (len(self.zeros) + 1) >= lowmax:
			self.stream.append( (lz_hi << 5) | (lz_zeros << 2) | ((len(self.zeros) - 1) >> 8) )
			self.stream.append( (len(self.zeros) - 1) & 0xff )
		elif len(self.zeros) > 1:
			self.stream.append( lz_zeros << 5 | (len(self.zeros) - 1) )
		else:
			raise Exception("checkWhitespace() should prevent this from happening")
	
	
	def checkAlts(self):
		"""count how many bytes from the current address alternate between two values
		
		the pair goes in self.alts and the count (capped at max_length) in
		self.num_alts. returns True if more than 2 bytes alternate. self.address
		is unchanged."""
		self.alts = []
		self.num_alts = 0
		self.getCurByte()
		
		# make sure we don't check for alts at the end of the file
		if self.address+2 >= self.end: return False
		
		self.alts.append(self.byte)
		self.alts.append(ord(self.image[self.address+1]))
		
		# byte k of the run must equal alts[k & 1], which is what the decompressor writes
		num_alts = 0
		while (self.address + num_alts < self.end and num_alts < max_length
				and ord(self.image[self.address + num_alts]) == self.alts[num_alts & 1]):
			num_alts += 1
		self.num_alts = num_alts
		return num_alts > 2
	
	def doAlts(self):
		"""queue an alternating-bytes command for the run found by checkAlts() and skip past it"""
		num_alts = self.num_alts
		
		if num_alts > lowmax:
			self.stream.append( (lz_hi << 5) | (lz_alt << 2) | ((num_alts - 1) >> 8) )
			# as with every other command, the stored count is one less than the real count
			self.stream.append( (num_alts - 1) & 0xff )
			self.stream.append( self.alts[0] )
			self.stream.append( self.alts[1] )
		elif num_alts > 2:
			self.stream.append( (lz_alt << 5) | (num_alts - 1) )
			self.stream.append( self.alts[0] )
			self.stream.append( self.alts[1] )
		else:
			raise Exception("checkAlts() should prevent this from happening")
		
		self.address += num_alts
	

	def checkIter(self):
		"""return True if the current byte occurs more than 3 times in a row (address unchanged)"""
		self.iters = []
		self.getCurByte()
		value = self.byte
		original_address = self.address
		while self.byte == value and len(self.iters) < max_length:
			self.iters.append(self.byte)
			self.next()
		self.address = original_address
		# 3 or fewer isn't worth the trouble and actually longer
		# if part of a larger literal set
		return len(self.iters) > 3
	
	def doIter(self):
		"""queue an iterate command for the run of the current byte and skip past it"""
		self.getCurByte()
		value = self.byte
		original_address = self.address
		
		self.iters = []
		while self.byte == value and len(self.iters) < max_length:
			self.iters.append(self.byte)
			self.next()
		
		if (len(self.iters) - 1) >= lowmax:
			self.stream.append( (lz_hi << 5) | (lz_iter << 2) | ((len(self.iters)-1) >> 8) )
			self.stream.append( (len(self.iters) - 1) & 0xff )
			self.stream.append( value )
		elif len(self.iters) > 3:
			# 3 or fewer isn't worth the trouble and actually longer
			# if part of a larger literal set
			self.stream.append( (lz_iter << 5) | (len(self.iters) - 1) )
			self.stream.append( value )
		else:
			self.address = original_address
			raise Exception("checkIter() should prevent this from happening")




class Decompressed:
	"""parse compressed 2bpp data
	
	parameters:
		[compressed 2bpp data] indexable; items are single-character strings or ints
		[tile arrangement] default: None ('vert' transposes the pic's tiles)
		[size of pic] default: None. side length in tiles; the first size*size
		              tiles become the pic and the rest animation tiles
		[start] default: 0
	
	splits output into pic [size] and animation tiles if applicable
	data can be fed in from rom if [start] is specified"""
	
	def __init__(self, lz = None, mode = None, size = None, start = 0):
		# todo: play nice with Compressed
	
		assert lz, 'need something to decompress!'
		self.lz = lz
		
		self.byte = None
		self.address = 0
		self.start = start
		
		self.output = []
		
		self.decompress()
		
		debug = False
		# print tuple containing start and end address
		if debug: print('(' + hex(self.start) + ', ' + hex(self.start + self.address+1) + '),')
		
		# only transpose pic
		self.pic = []
		self.animtiles = []
		
		if size is not None:
			self.tiles = get_tiles(self.output)
			self.pic = connect(self.tiles[:(size*size)])
			self.animtiles = connect(self.tiles[(size*size):])
		else: self.pic = self.output
		
		if mode == 'vert':
			self.tiles = get_tiles(self.pic)
			self.tiles = transpose(self.tiles)
			self.pic = connect(self.tiles)
		
		self.output = self.pic + self.animtiles
	
	
	def decompress(self):
		"""replica of crystal's decompression
		
		reads commands from self.lz (starting at self.start) until an lz_end
		command byte and appends the decompressed bytes to self.output. afterwards
		self.address is the offset of that lz_end byte relative to self.start."""
		
		self.output = []
		
		while True:
			self.getCurByte()
			
			if (self.byte == lz_end):
				break
			
			self.cmd = (self.byte & 0b11100000) >> 5
			
			if self.cmd == lz_hi: # 10-bit param
				self.cmd = (self.byte & 0b00011100) >> 2
				self.length = (self.byte & 0b00000011) << 8
				self.next()
				self.length += self.byte + 1
			else: # 5-bit param
				self.length = (self.byte & 0b00011111) + 1
			
			# literals
			if self.cmd == lz_lit:
				self.doLiteral()
			elif self.cmd == lz_iter:
				self.doIter()
			elif self.cmd == lz_alt:
				self.doAlt()
			elif self.cmd == lz_zeros:
				self.doZeros()
				
			else: # repeaters
				self.next()
				if self.byte > 0x7f: # negative: count back from the end of the output so far
					self.displacement = self.byte & 0x7f
					self.displacement = len(self.output) - self.displacement - 1
				else: # positive: 2-byte big-endian offset from the start of the output
					self.displacement = self.byte * 0x100
					self.next()
					self.displacement += self.byte
				
				if self.cmd == lz_flip:
					self.doFlip()
				elif self.cmd == lz_reverse:
					self.doReverse()
				else: # lz_repeat
					self.doRepeat()
			
			# step past the last byte this command consumed
			self.address += 1
	
	
	def getCurByte(self):
		"""set self.byte to the value at self.start + self.address"""
		value = self.lz[self.start+self.address]
		# str items (py2 rom data) need ord(); bytearray/bytes items are already ints
		self.byte = value if isinstance(value, int) else ord(value)
	
	def next(self):
		self.address += 1
		self.getCurByte()
	
	def doLiteral(self):
		# copy 2bpp data directly
		for byte in range(self.length):
			self.next()
			self.output.append(self.byte)
		
	def doIter(self):
		# write one byte repeatedly
		self.next()
		for byte in range(self.length):
			self.output.append(self.byte)
		
	def doAlt(self):
		# write alternating bytes
		self.alts = []
		self.next()
		self.alts.append(self.byte)
		self.next()
		self.alts.append(self.byte)
		
		for byte in range(self.length):
			self.output.append(self.alts[byte&1])
		
	def doZeros(self):
		# write zeros
		for byte in range(self.length):
			self.output.append(0x00)
		
	def doFlip(self):
		# repeat flipped bytes from 2bpp output
		# eg  11100100 -> 00100111
		# quat 3 2 1 0 ->  0 2 1 3
		for byte in range(self.length):
			flipped = sum(1<<(7-i) for i in range(8) if self.output[self.displacement+byte]>>i&1)
			self.output.append(flipped)
		
	def doReverse(self):
		# repeat reversed bytes from 2bpp output
		for byte in range(self.length):
			self.output.append(self.output[self.displacement-byte])
		
	def doRepeat(self):
		# repeat bytes from 2bpp output (may read bytes appended by this same command)
		for byte in range(self.length):
			self.output.append(self.output[self.displacement+byte])



# front pic side lengths in tiles (the pic is size x size tiles), indexed by
# species number - 1. the row format matches what make_sizes() prints;
# presumably generated with it -- confirm against the ROM if in doubt.
sizes = [
	5, 6, 7, 5, 6, 7, 5, 6, 7, 5, 5, 7, 5, 5, 7, 5,
	6, 7, 5, 6, 5, 7, 5, 7, 5, 7, 5, 6, 5, 6, 7, 5,
	6, 7, 5, 6, 6, 7, 5, 6, 5, 7, 5, 6, 7, 5, 7, 5,
	7, 5, 7, 5, 7, 5, 7, 5, 7, 5, 7, 5, 6, 7, 5, 6,
	7, 5, 7, 7, 5, 6, 7, 5, 6, 5, 6, 6, 6, 7, 5, 7,
	5, 6, 6, 5, 7, 6, 7, 5, 7, 5, 7, 7, 6, 6, 7, 6,
	7, 5, 7, 5, 5, 7, 7, 5, 6, 7, 6, 7, 6, 7, 7, 7,
	6, 6, 7, 5, 6, 6, 7, 6, 6, 6, 7, 6, 6, 6, 7, 7,
	6, 7, 7, 5, 5, 6, 6, 6, 6, 5, 6, 5, 6, 7, 7, 7,
	7, 7, 5, 6, 7, 7, 5, 5, 6, 7, 5, 6, 7, 5, 6, 7,
	6, 6, 5, 7, 6, 6, 5, 7, 7, 6, 6, 5, 5, 5, 5, 7,
	5, 6, 5, 6, 7, 7, 5, 7, 6, 7, 5, 6, 7, 5, 5, 6,
	6, 5, 6, 6, 6, 6, 7, 6, 5, 6, 7, 5, 7, 6, 6, 7,
	6, 6, 5, 7, 5, 6, 6, 5, 7, 5, 6, 5, 6, 6, 5, 6,
	6, 7, 7, 6, 7, 7, 5, 7, 6, 7, 7, 5, 7, 5, 6, 6,
	6, 7, 7, 7, 7, 5, 6, 7, 7, 7, 5,
]

def make_sizes():
	"""front pics have specified sizes

	Print each of the 251 monsters' front pic size in the layout of the `sizes`
	table: 16 values per tab-indented row. A size is the low nybble of byte 0x11
	of that monster's 0x20-byte base stats entry.
	"""
	base_stats = 0x51424
	entry_length = 0x20
	pieces = []
	for index in range(251):
		if index % 16 == 0:
			pieces.append('\n\t')
		size_byte = ord(rom[base_stats + 0x11 + index * entry_length])
		pieces.append('%d, ' % (size_byte & 0x0f))
	print(''.join(pieces))



# pointer table for the fx graphics: num_fx entries of 4 bytes each
# (tile count, bank, address lo, address hi), as read by decompress_fx_by_id
fxs = 0xcfcf6
num_fx = 40

def decompress_fx_by_id(id):
	"""Decompress fx graphic `id` from the ROM and return the Decompressed object.

	Each 4-byte pointer table entry holds the tile count, the bank, and a
	little-endian address.
	"""
	entry = fxs + id * 4 # len_fxptr
	num_tiles, bank, lo, hi = [ord(rom[entry + offset]) for offset in range(4)]
	# banked pointer -> absolute rom offset
	address = bank * 0x4000 + (((hi << 8) | lo) & 0x3fff)
	# NOTE(review): Decompressed treats its size argument as a side length (it keeps
	# the first size*size tiles as the pic), so passing a tile count looks suspicious; confirm
	return Decompressed(rom, 'horiz', num_tiles, address)
	
def decompress_fx():
	"""Write every fx graphic's pic to ../gfx/fx/NNN.2bpp (e.g. ../gfx/fx/039.2bpp)."""
	for fx_id in range(num_fx):
		to_file('../gfx/fx/%03d.2bpp' % fx_id, decompress_fx_by_id(fx_id).pic)


# each monster has two pic pointers: front (0) then back (1)
num_pics = 2
front = 0
back = 1

# pointer table of 3-byte (bank, address lo, address hi) entries, front then back per monster
monsters = 0x120000
num_monsters = 251

# Unown's forms use their own pointer table with the same layout;
# unown_dex is Unown's 1-based species number
unowns = 0x124000
num_unowns = 26
unown_dex = 201

def decompress_monster_by_id(id = 0, type = front):
	"""Decompress the front or back pic of monster `id` (0-based, species number - 1).

	Returns a Decompressed object, or None for Unown (its pics have their own
	pointer table, see decompress_unown_by_id).
	"""
	if id + 1 == unown_dex:
		return None
	# front pics carry animation tiles after the pic, so they need a size to split on
	size = sizes[id] if type == front else None
	entry = monsters + (id * 2 + type) * 3 # bank, address
	bank = ord(rom[entry]) + 0x36 # crystal
	pointer = ord(rom[entry + 1]) | (ord(rom[entry + 2]) << 8)
	return Decompressed(rom, 'vert', size, bank * 0x4000 + (pointer & 0x3fff))
	
def decompress_monsters(type = front):
	"""Write each monster's pics under ../gfx/pics/NNN/ (NNN = species number).

	Front: pic to front.2bpp and animation tiles to tiles.2bpp. Back: pic to
	back.2bpp. Unown is skipped.
	"""
	for index in range(num_monsters):
		monster = decompress_monster_by_id(index, type)
		if monster is None: # no unowns here
			continue
		folder = '../gfx/pics/' + str(index + 1).zfill(3) + '/'
		if type: # back
			to_file(folder + 'back.2bpp', monster.pic)
		else: # front
			to_file(folder + 'front.2bpp', monster.pic)
			to_file(folder + 'tiles.2bpp', monster.animtiles)


def decompress_unown_by_id(letter, type = front):
	"""Decompress the front or back pic of Unown form `letter` (0 = a).

	Every form uses Unown's front pic size from `sizes`.
	"""
	size = sizes[unown_dex - 1] if type == front else None
	entry = unowns + (letter * 2 + type) * 3 # bank, address
	bank = ord(rom[entry]) + 0x36 # crystal
	pointer = (ord(rom[entry + 2]) << 8) | ord(rom[entry + 1])
	return Decompressed(rom, 'vert', size, bank * 0x4000 + (pointer & 0x3fff))

def decompress_unowns(type = front):
	"""Write each Unown form's pic under ../gfx/pics/<dex><letter>/ (e.g. 201a)."""
	for letter in range(num_unowns):
		unown = decompress_unown_by_id(letter, type)
		folder = '../gfx/pics/' + str(unown_dex).zfill(3) + chr(ord('a') + letter) + '/'
		if type: # back
			to_file(folder + 'back.2bpp', unown.pic)
		else: # front
			to_file(folder + 'front.2bpp', unown.pic)
			# NOTE(review): every letter writes its animation tiles to this same path,
			# so only the last letter's tiles survive; confirm that is intended
			to_file('../gfx/anim/' + 'tiles.2bpp', unown.animtiles)


# pointer table of 3-byte (bank, address lo, address hi) trainer pic entries
trainers = 0x128000
num_trainers = 67

def decompress_trainer_by_id(id):
	"""Decompress trainer pic `id` from the ROM ('vert' tiles, no animation-tile split)."""
	entry = trainers + id * 3 # bank, address
	bank = ord(rom[entry]) + 0x36 # crystal
	pointer = ord(rom[entry + 1]) + (ord(rom[entry + 2]) << 8)
	return Decompressed(rom, 'vert', None, bank * 0x4000 + (pointer & 0x3fff))

def decompress_trainers():
	"""Write every trainer pic to ../gfx/trainers/NNN.2bpp (e.g. ../gfx/trainers/066.2bpp)."""
	for trainer_id in range(num_trainers):
		pic = decompress_trainer_by_id(trainer_id).pic
		to_file('../gfx/trainers/%03d.2bpp' % trainer_id, pic)


# in order of use (sans repeats)
# (output name, rom address) of each compressed intro graphic;
# entries marked 'tilemap' are tilemaps rather than tile graphics
intro_gfx = [
	('logo', 0x109407),
	('001', 0xE641D), # tilemap
	('unowns', 0xE5F5D),
	('pulse', 0xE634D),
	('002', 0xE63DD), # tilemap
	('003', 0xE5ECD), # tilemap
	('background', 0xE5C7D),
	('004', 0xE5E6D), # tilemap
	('005', 0xE647D), # tilemap
	('006', 0xE642D), # tilemap
	('pichu_wooper', 0xE592D),
	('suicune_run', 0xE555D),
	('007', 0xE655D), # tilemap
	('008', 0xE649D), # tilemap
	('009', 0xE76AD), # tilemap
	('suicune_jump', 0xE6DED),
	('unown_back', 0xE785D),
	('010', 0xE764D), # tilemap
	('011', 0xE6D0D), # tilemap
	('suicune_close', 0xE681D),
	('012', 0xE6C3D), # tilemap
	('013', 0xE778D), # tilemap
	('suicune_back', 0xE72AD),
	('014', 0xE76BD), # tilemap
	('015', 0xE676D), # tilemap
	('crystal_unowns', 0xE662D),
	('017', 0xE672D), # tilemap
]

def decompress_intro():
	"""Decompress each intro_gfx entry to ../gfx/intro/<name>.2bpp."""
	for name, address in intro_gfx:
		gfx = Decompressed(rom, 'horiz', None, address)
		to_file('../gfx/intro/%s.2bpp' % name, gfx.output)


# (output name, rom address) of each compressed title screen graphic
title_gfx = [
	('suicune', 0x10EF46),
	('logo', 0x10F326),
	('crystal', 0x10FCEE),
]

def decompress_title():
	"""Decompress each title_gfx entry to ../gfx/title/<name>.2bpp."""
	for entry in title_gfx:
		name, address = entry
		to_file('../gfx/title/' + name + '.2bpp', Decompressed(rom, 'horiz', None, address).output)

def decompress_tilesets():
	"""Decompress the graphics of each of the 0x25 tilesets to ../gfx/tilesets/NN.2bpp.

	Each 15-byte tileset header starts with a pointer: bank, address lo, address hi.
	"""
	tileset_headers = 0x4d596
	header_length = 15
	for tileset in range(0x25):
		header = tileset_headers + tileset * header_length
		bank = ord(rom[header])
		pointer = ord(rom[header + 1]) + ord(rom[header + 2]) * 0x100
		tiles = Decompressed(rom, 'horiz', None, bank * 0x4000 + (pointer & 0x3fff))
		to_file('../gfx/tilesets/%02d.2bpp' % tileset, tiles.output)

# (output name, rom address, tile arrangement) of other compressed graphics
misc = [
	('player', 0x2BA1A, 'vert'),
	('dude', 0x2BBAA, 'vert'),
	('town_map', 0xF8BA0, 'horiz'),
	('pokegear', 0x1DE2E4, 'horiz'),
	('pokegear_sprites', 0x914DD, 'horiz'),
]
def decompress_misc():
	"""Decompress each misc entry, using its tile arrangement, to ../gfx/misc/<name>.2bpp."""
	for entry in misc:
		name, address, mode = entry
		to_file('../gfx/misc/' + name + '.2bpp', Decompressed(rom, mode, None, address).output)

def decompress_all(debug = False):
	"""decompress all known compressed data in baserom

	When `debug` is true, each group's label is printed before it is processed.
	"""
	steps = [
		('fronts', decompress_monsters, (front,)),
		('backs', decompress_monsters, (back,)),
		('unown fronts', decompress_unowns, (front,)),
		('unown backs', decompress_unowns, (back,)),
		('trainers', decompress_trainers, ()),
		('fx', decompress_fx, ()),
		('intro', decompress_intro, ()),
		('title', decompress_title, ()),
		('tilesets', decompress_tilesets, ()),
		('misc', decompress_misc, ()),
	]
	for label, step, args in steps:
		if debug:
			print(label)
		step(*args)


def decompress_from_address(address, mode='horiz', filename = 'de.2bpp', size = None):
	"""write decompressed data from an address to a 2bpp file

	Only the pic is written; animation tiles split off by `size` are dropped.
	"""
	to_file(filename, Decompressed(rom, mode, size, address).pic)


def decompress_file(filein, fileout, mode = 'horiz', size = None):
	"""Decompress the lz file `filein` and write the resulting pic to `fileout`.

	`mode` and `size` are passed to Decompressed. 'vert' transposes the pic's
	tiles; `size` splits off animation tiles, which are not written.
	"""
	with open(filein, 'rb') as f:
		image = f.read()
	
	de = Decompressed(image, mode, size)
	
	to_file(fileout, de.pic)


def compress_file(filein, fileout, mode = 'horiz'):
	"""Compress the 2bpp file `filein` and write the lz data to `fileout`.

	`mode` is passed to Compressed ('vert' transposes the tiles first).
	"""
	with open(filein, 'rb') as f:
		image = f.read()
	
	lz = Compressed(image, mode)
	
	to_file(fileout, lz.output)




def compress_monster_frontpic(id, fileout):
	"""Compress monster `id`'s front pic plus animation tiles into ../gfx/pics/NNN/front.lz.

	`id` is the 1-based species number. front.2bpp and tiles.2bpp are joined
	and compressed in 'vert' mode, split by the pic size from `sizes`.
	NOTE(review): `fileout` is ignored and output always goes to
	../gfx/pics/NNN/front.lz; confirm whether callers expect it to be honored.
	"""
	mode = 'vert'
	folder = '../gfx/pics/' + str(id).zfill(3) + '/'
	
	with open(folder + 'front.2bpp', 'rb') as f:
		pic = f.read()
	with open(folder + 'tiles.2bpp', 'rb') as f:
		anim = f.read()
	
	lz = Compressed(pic + anim, mode, sizes[id-1])
	
	to_file(folder + 'front.lz', lz.output)



def get_uncompressed_gfx(start, num_tiles, filename):
	"""grab tiles directly from rom and write to file

	Copies `num_tiles` 16-byte tiles starting at rom offset `start` into `filename`.
	"""
	end = start + num_tiles * 0x10
	rom = load_rom()
	to_file(filename, [ord(rom[offset]) for offset in range(start, end)])



def hex_to_rgb(word):
	"""Split a 15-bit GBC color word into its 5-bit (red, green, blue) channels.

	Red is bits 0-4, green bits 5-9, blue bits 10-14.
	"""
	return tuple((word >> shift) & 0x1f for shift in (0, 5, 10))

def grab_palettes(address, length = 0x80):
	"""Return `length` bytes of rom palette data at `address` as RGB macro lines.

	Each little-endian color word becomes one tab-indented 'RGB rr, gg, bb' line
	(two-digit decimal channels) ending in a newline.
	"""
	lines = []
	# floor division keeps this an int on both python 2 and 3
	for word in range(length // 2):
		offset = address + word * 2
		color = hex_to_rgb(ord(rom[offset + 1]) * 0x100 + ord(rom[offset]))
		lines.append('\tRGB ' + ', '.join(str(channel).zfill(2) for channel in color) + '\n')
	return ''.join(lines)







def dump_monster_pals():
	"""Write each monster's palettes to ../gfx/pics/NNN/normal.pal and shiny.pal.

	The table at 0xa8d6 holds, per monster, pal_length bytes of the normal
	palette immediately followed by pal_length bytes of the shiny palette.
	"""
	rom = load_rom()
	
	pals = 0xa8d6
	pal_length = 0x4
	for mon in range(251):
		folder = '../gfx/pics/' + str(mon + 1).zfill(3) + '/'
		address = pals + mon * pal_length * 2
		for filename in ('normal.pal', 'shiny.pal'):
			pal_data = [ord(rom[address + i]) for i in range(pal_length)]
			to_file(folder + filename, pal_data)
			address += pal_length


def dump_trainer_pals():
	"""Write each trainer group's palette to ../gfx/trainers/NNN.pal and print an INCBIN line for it."""
	rom = load_rom()
	
	pals = 0xb0d2
	pal_length = 0x4
	for trainer in range(67):
		
		name = trainer_group_names[trainer+1]['constant'].title().replace('_','')
		num = str(trainer).zfill(3)
		folder = 'gfx/trainers/'
		
		address = pals + trainer*pal_length
		pal_data = [ord(rom[address + i]) for i in range(pal_length)]
		
		filename = num+'.pal'
		to_file('../'+folder+filename, pal_data)
		
		spacing = ' ' * (12 - len(name))
		# keep a space between INCBIN and the quoted path
		print(name+'Palette:'+spacing+' INCBIN "'+folder+filename+'"')



def flatten(planar):
	"""
	Flattens planar 2bpp image data into a quaternary pixel map.

	`planar` yields single-character strings or ints. Each pair of bytes (low
	bitplane, then high bitplane) is one 8-pixel row, leftmost pixel (bit 7)
	first. Each pixel is low bit + 2 * high bit, i.e. 0-3. A trailing unpaired
	byte is ignored.
	"""
	values = [b if isinstance(b, int) else ord(b) for b in planar]
	strips = []
	# floor division keeps this an int on both python 2 and 3
	for pair in range(len(values) // 2):
		bottom = values[pair * 2]
		top = values[pair * 2 + 1]
		for bit in range(7, -1, -1):
			strips.append(((bottom >> bit) & 1) | (((top >> bit) & 1) << 1))
	return strips


def to_lines(image, width):
	"""
	Converts a tiled quaternary pixel map to lines of quaternary pixels.

	`image` stores 8x8 tiles one after another (64 pixels each, row by row).
	The tiles are laid out left to right, then top to bottom, across `width`
	pixels. Returns one list of `width` pixels per image line.
	"""
	tile = 8 * 8
	
	# so we know how many strips of 8px we're putting into a line
	num_columns = width // 8
	# number of lines (floor division keeps these ints on python 2 and 3)
	height = len(image) // width
	
	lines = []
	for cur_line in range(height):
		tile_row = cur_line // 8
		line = []
		for column in range(num_columns):
			# start of this tile, plus the offset of the wanted row inside it
			anchor = num_columns*tile_row*tile + column*tile + (cur_line%8)*8
			line += image[anchor:anchor+8]
		lines.append(line)
	return lines

def dmg2rgb(word):
	"""Convert a 15-bit GBC color word to an 8-bit-per-channel (r, g, b, a) tuple.

	Each 5-bit channel c becomes c * 8 + 4. Alpha is always 255.
	"""
	channels = [(word >> shift) & 0x1f for shift in (0, 5, 10)]
	red, green, blue = [(channel << 3) + 0b100 for channel in channels]
	return (red, green, blue, 255)


def png_pal(filename):
	"""Build a png palette from a GBC palette file.

	The file holds little-endian 15-bit color words; a trailing odd byte is
	ignored. Returns RGBA tuples: white, then each color in the file, then black.
	"""
	with open(filename, 'rb') as pal_data:
		# bytearray gives ints on both python 2 and 3
		data = bytearray(pal_data.read())
	palette = [(255, 255, 255, 255)]
	for i in range(0, len(data) - 1, 2):
		palette.append(dmg2rgb(data[i] | (data[i + 1] << 8)))
	palette.append((0, 0, 0, 255))
	return palette


def to_png(filein, fileout=None, pal_file=None, height=None, width=None):
	"""
	Takes a planar 2bpp graphics file and converts it to png.

	`fileout` defaults to `filein` with its extension replaced by .png.
	Dimensions are resolved as follows:
	- with neither dimension given, the image is assumed to be square;
	- with one given, the other is derived from the pixel count;
	- if that does not cover the data exactly, the most square rectangle with
	  tile-aligned sides (width a multiple of 8, at most 256) is used;
	- failing that, the width falls back to a single tile.
	Without `pal_file` the png is 2-bit greyscale with color 0 as the lightest
	shade. With it, the png uses the palette built by png_pal().
	Returns 'empty image!' and writes nothing if `filein` is empty.
	"""
	
	if fileout is None: fileout = '.'.join(filein.split('.')[:-1]) + '.png'
	
	with open(filein, 'rb') as f:
		image = f.read()
	
	num_pixels = len(image) * 4
	
	if num_pixels == 0: return 'empty image!'
	
	# unless the pic is square, at least one dimension should be given
	if width is None and height is None:
		width = int(sqrt(num_pixels))
		height = width
	elif height is None:
		height = num_pixels // width
	elif width is None:
		width = num_pixels // height
	
	# but try to see if it can be made rectangular
	if width * height != num_pixels:
		# tile-aligned widths that split the pixels exactly into whole tile rows
		matches = [
			(w, num_pixels // w)
			for w in range(8, 256 + 1, 8)
			if num_pixels % w == 0 and (num_pixels // w) % 8 == 0
		]
		if matches:
			# go for the most square image (ties keep the narrower width, favoring height)
			width, height = min(matches, key=lambda wh: wh[0] + wh[1])
	
	# if it can't, the only option is a width of 1 tile
	if width * height != num_pixels:
		width = 8
		height = num_pixels // width
	
	# if this still isn't rectangular, then the image isn't made of tiles
	# for now we'll just spit out a warning
	if width * height != num_pixels:
		print('Warning! %s is %dx%d (%d pixels),\nbut %s is %d pixels!' % (
			fileout, width, height, width * height, filein, num_pixels))
	
	# map it out
	lines = to_lines(flatten(image), width)
	
	if pal_file is None:
		palette = None
		greyscale = True
		bitdepth = 2
		# greyscale 0 is black, but color 0 is the lightest shade, so invert
		pixel_rows = [[3 - pixel for pixel in line] for line in lines]
	else: # gbc color
		palette = png_pal(pal_file)
		greyscale = False
		bitdepth = 8
		pixel_rows = [list(line) for line in lines]
	
	writer = png.Writer(width, height, palette=palette, compression = 9, greyscale = greyscale, bitdepth = bitdepth)
	with open(fileout, 'wb') as out:
		writer.write(out, pixel_rows)




def to_2bpp(filein, fileout=None, palout=None):
	"""
	Takes a png and converts it to planar 2bpp.

	filein:  path of the source png.
	fileout: path of the 2bpp output. Defaults to filein with its last
	         '.'-separated extension replaced by '.2bpp'.
	palout:  currently ignored; palette output is disabled (see below).

	The png is read as 8-bit RGBA. Its distinct colors (at most 4) are
	ordered by luminance() and each pixel becomes its index in that order.
	The output is emitted one 8x8 tile at a time, left to right, then top
	to bottom. Each pixel row of a tile becomes two bytes: the low bit
	plane, then the high bit plane.

	NOTE: a width or height that is not a multiple of 8 leaves a partial
	tile at the edge, and that partial tile is dropped.

	Raises ValueError if the image has more than 4 distinct colors.
	"""

	if fileout is None:
		fileout = '.'.join(filein.split('.')[:-1]) + '.2bpp'

	with open(filein, 'rb') as f:
		r    = png.Reader(f)
		info = r.asRGBA8()

		width  = info[0]
		height = info[1]

		# materialize the rows while the file is still open
		rgba = list(info[2])


	# padding around the image is disabled for now (always zero).
	# If it is ever enabled, the tile slicing below must use the padded width.
	padding = { 'left':   0,
	            'right':  0,
	            'top':    0,
	            'bottom': 0, }


	# turn the flat rgba values into rows of color dicts,
	# collecting the distinct colors (the palette) along the way

	pixel_length = 4 # rgba
	image   = []
	palette = []

	for line in rgba:
		newline = []
		for pixel in range(len(line) // pixel_length):
			i = pixel * pixel_length
			color = { 'r': line[i  ],
			          'g': line[i+1],
			          'b': line[i+2],
			          'a': line[i+3], }
			newline.append(color)
			if color not in palette:
				palette.append(color)
		image.append(newline)


	# sort by luminance, because we can

	def luminance(color):
		# negated so that sorting ascending puts the largest weighted sum
		# first (the reverse order, thanks to dmg/cgb palette ordering)
		rough = { 'r':  4.7,
		          'g':  1.4,
		          'b': 13.8, }
		return sum(color[key] * -rough[key] for key in rough.keys())

	palette = sorted(palette, key=luminance)

	# no palette fixing for now. Raise rather than assert: under -O an
	# assert vanishes and indices >= 4 would be silently truncated to 2 bits.
	if len(palette) > 4:
		raise ValueError('Palette should be 4 colors, is really ' + str(len(palette)))


	# spit out new palette (disabled for now)

	def rgb_to_dmg(color):
		# pack 8-bit channels into a 15-bit word: 5 bits each, red highest
		word =  (color['r'] // 8) << 10
		word += (color['g'] // 8) <<  5
		word += (color['b'] // 8)
		return word

	palout = None # deliberately disabled: the palout argument is ignored

	if palout is not None:
		output = []
		for color in palette[1:3]:
			word = rgb_to_dmg(color)
			output.append(word >> 8)
			output.append(word & 0xff)
		to_file(palout, output)


	# create a flat list of quaternary color ids, row by row

	row_width = width + padding['left'] + padding['right']
	indices = [0] * (row_width * padding['top'])
	for line in image:
		indices += [0] * padding['left']
		indices.extend(palette.index(color) for color in line)
		indices += [0] * padding['right']
	indices += [0] * (row_width * padding['bottom'])


	# split it into 8x8 tiles, and make each pixel row planar

	num_columns = width // 8
	num_rows    = height // 8

	tiles = []
	for row in range(num_rows):
		for column in range(num_columns):
			for strip in range(8):
				# index of the first pixel of this tile's strip-th row
				anchor = (row * 8 + strip) * width + column * 8
				line   = indices[anchor:anchor + 8]
				low  = 0
				high = 0
				for bit, quad in enumerate(line):
					low  |= (quad & 1) << (7 - bit)
					high |= ((quad >> 1) & 1) << (7 - bit)
				tiles.append(low)
				tiles.append(high)

	to_file(fileout, tiles)


def png_to_lz(filein):
	"""
	Convert a png to 2bpp, then lz-compress the result.

	Writes <name>.2bpp and <name>.lz, where <name> is filein with its
	extension removed (os.path.splitext).
	"""

	name = os.path.splitext(filein)[0]

	# pass the output path explicitly: to_2bpp's default derives it by
	# splitting on '.', which can disagree with splitext (e.g. dotted dirs)
	to_2bpp(filein, name + '.2bpp')
	with open(name + '.2bpp', 'rb') as f:
		image = f.read()
	to_file(name + '.lz', Compressed(image).output)




def mass_to_png(debug=False):
	# greyscale
	for root, dirs, files in os.walk('../gfx/'):
		for name in files:
			if debug: print os.path.splitext(name), os.path.join(root, name)
			if os.path.splitext(name)[1] == '.2bpp':
				to_png(os.path.join(root, name))

def mass_to_colored_png(debug=False):
	# greyscale, unless a palette is detected
	for root, dirs, files in os.walk('../gfx/'):
		if 'pics' not in root and 'trainers' not in root:
			for name in files:
				if debug: print os.path.splitext(name), os.path.join(root, name)
				if os.path.splitext(name)[1] == '.2bpp':
					if name[:5]+'.pal' in files:
						to_png(os.path.join(root, name), None, os.path.join(root, name[:-5]+'.pal'))
					else:
						to_png(os.path.join(root, name))
	
	# only monster and trainer pics for now
	for root, dirs, files in os.walk('../gfx/pics/'):
		for name in files:
			if debug: print os.path.splitext(name), os.path.join(root, name)
			if os.path.splitext(name)[1] == '.2bpp':
				if 'normal.pal' in files:
					to_png(os.path.join(root, name), None, os.path.join(root, 'normal.pal'))
				else:
					to_png(os.path.join(root, name))
	for root, dirs, files in os.walk('../gfx/trainers/'):
		for name in files:
			if debug: print os.path.splitext(name), os.path.join(root, name)
			if os.path.splitext(name)[1] == '.2bpp':
				to_png(os.path.join(root, name), None, os.path.join(root, name[:-5]+'.pal'))


def mass_decompress(debug=False):
	"""
	Decompress every .lz file under ../gfx/ into a .2bpp next to it.

	- In a '/pics' directory:
	  - a 'front' file is decompressed in 'vert' mode into front.2bpp and
	    tiles.2bpp (its animtiles). The size is sizes[id - 1], where id is
	    the first three characters after 'pics/' in the directory path, or 4
	    for the egg.
	  - a 'back' file is decompressed in 'vert' mode into back.2bpp.
	- In a '/trainers' or '/fx' directory: <name>.2bpp, 'vert' mode.
	- Anywhere else: <name>.2bpp, default mode.

	debug is accepted for interface compatibility and is unused.
	"""
	for root, dirs, files in os.walk('../gfx/'):
		for filename in files:
			if not filename.endswith('.lz'):
				continue

			path = os.path.join(root, filename)
			stem = filename[:-3] # strip '.lz'
			with open(path, 'rb') as lz:
				data = lz.read()

			if '/pics' in root:
				if 'front' in filename:
					pic_id = root.split('pics/')[1][:3]
					size = 4 if pic_id == 'egg' else sizes[int(pic_id) - 1]
					de = Decompressed(data, 'vert', size)
					to_file(os.path.join(root, 'front.2bpp'), de.pic)
					to_file(os.path.join(root, 'tiles.2bpp'), de.animtiles)
				elif 'back' in filename:
					de = Decompressed(data, 'vert')
					to_file(os.path.join(root, 'back.2bpp'), de.output)

			elif '/trainers' in root or '/fx' in root:
				de = Decompressed(data, 'vert')
				to_file(os.path.join(root, stem + '.2bpp'), de.output)

			else:
				# previously root+file with no separator, which wrote
				# subdirectory output into the parent directory
				de = Decompressed(data)
				to_file(os.path.join(root, stem + '.2bpp'), de.output)

def append_terminator_to_lzs(directory):
	"""
	Append a 0xff terminator to every .lz file under directory that lacks one.

	Fixes lzs that were extracted with a missing terminator. The directory
	is walked recursively. Files already ending in 0xff are left untouched,
	and an empty .lz gets a lone terminator.
	"""
	for root, dirs, files in os.walk(directory):
		for name in files:
			if not name.endswith('.lz'):
				continue
			# os.walk roots normally lack a trailing separator, so join
			# properly (root+name broke every subdirectory)
			path = os.path.join(root, name)
			with open(path, 'rb') as f:
				data = f.read()
			# b'\xff' is str on Python 2 and bytes on Python 3; endswith also
			# handles empty files, where data[-1] would raise IndexError
			if not data.endswith(b'\xff'):
				with open(path, 'ab') as f:
					f.write(b'\xff')




if __name__ == "__main__":
	parser = argparse.ArgumentParser()
	parser.add_argument('cmd',  nargs='?', metavar='cmd',  type=str)
	parser.add_argument('arg1', nargs='?', metavar='arg1', type=str)
	parser.add_argument('arg2', nargs='?', metavar='arg2', type=str)
	parser.add_argument('arg3', nargs='?', metavar='arg3', type=str)
	parser.add_argument('arg4', nargs='?', metavar='arg4', type=str)
	parser.add_argument('arg5', nargs='?', metavar='arg5', type=str)
	args = parser.parse_args()
	
	debug = False
	
	if args.cmd == 'dump-pngs':
		mass_to_colored_png()
	
	elif args.cmd == 'png-to-lz':
		# python gfx.py png-to-lz [--front anim(2bpp) | --vert] [png]
		
		# python gfx.py png-to-lz --front [anim(2bpp)] [png]
		if args.arg1 == '--front':

			# front.png and tiles.png are combined before compression,
			# so we have to pass in things like anim file and pic size
			name = os.path.splitext(args.arg3)[0]
			
			to_2bpp(name+'.png', name+'.2bpp')
			pic  = open(name+'.2bpp', 'rb').read()
			anim = open(args.arg2, 'rb').read()
			size = int(sqrt(len(pic)/16)) # assume square pic
			to_file(name+'.lz', Compressed(pic + anim, 'vert', size).output)
		
		
		# python gfx.py png-to-lz --vert [png]
		elif args.arg1 == '--vert':
			
			# others are vertically oriented (frontpics are always vertical)
			
			name = os.path.splitext(args.arg2)[0]
			
			to_2bpp(name+'.png', name+'.2bpp')
			pic = open(name+'.2bpp', 'rb').read()
			to_file(name+'.lz', Compressed(pic + anim, 'vert').output)
		
		
		# python gfx.py png-to-lz [png]
		else:
			
			# standard usage
			
			png_to_lz(args.arg1)
	
	elif args.cmd == 'png-to-2bpp':
		to_2bpp(args.arg1)
	
	
	elif args.cmd == 'de':
		# python gfx.py de [addr] [fileout] [mode]
		
		rom = load_rom()
		
		addr = int(args.arg1,16)
		fileout = args.arg2
		mode = args.arg3
		decompress_from_address(addr, fileout, mode)
		if debug: print 'decompressed to ' + args.arg2 + ' from ' + hex(int(args.arg1,16)) + '!'
		
	elif args.cmd == 'lz':
		# python gfx.py lz [filein] [fileout] [mode]
		filein = args.arg1
		fileout = args.arg2
		mode = args.arg3
		compress_file(filein, fileout, mode)
		if debug: print 'compressed ' + filein + ' to ' + fileout + '!'
	
	elif args.cmd == 'lzf':
		# python gfx.py lzf [id] [fileout]
		compress_monster_frontpic(int(args.arg1), args.arg2)
	
	elif args.cmd == 'un':
		# python gfx.py un [address] [num_tiles] [filename]
		rom = load_rom()
		get_uncompressed_gfx(int(args.arg1,16), int(args.arg2), args.arg3)
	
	elif args.cmd == 'pal':
		# python gfx.py pal [address] [length]
		rom = load_rom()
		print grab_palettes(int(args.arg1,16), int(args.arg2))
	
	elif args.cmd == 'png':
		
		if '.2bpp' in args.arg1:
			if args.arg3 == 'greyscale':
				to_png(args.arg1, args.arg2)
			else:
				to_png(args.arg1, args.arg2, args.arg3)
		
		elif '.png' in args.arg1:
			to_2bpp(args.arg1, args.arg2)
	
	elif args.cmd == 'mass-decompress':
		mass_decompress()
		if debug: print 'decompressed known gfx to pokecrystal/gfx/!'