package charactermanaj.graphics.io;

import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferInt;
import java.awt.image.WritableRaster;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * 複数レイヤー画像をPSD形式のデータとして作成する。
 * https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/
 *
 * @author seraphy
 */
public final class PSDCreator {

	/**
	 * レイヤーデータ
	 */
	public static class LayerData {

		/**
		 * レイヤー名
		 */
		private String layerName;

		/**
		 * レイヤーの画像(TYPE_INT_ARGB限定)
		 */
		private BufferedImage image;

		public LayerData(String layerName, BufferedImage image) {
			this.layerName = layerName;
			this.image = image;
		}

		public String getLayerName() {
			return layerName;
		}

		public BufferedImage getImage() {
			return image;
		}
	}

	/**
	 * レイヤーとチャネルのペア
	 */
	static final class LayerChannelPair {

		private final LayerData layerData;

		private final int channel;

		public LayerChannelPair(LayerData layerData, int channel) {
			this.layerData = layerData;
			this.channel = channel;
		}

		public LayerData getLayerData() {
			return layerData;
		}

		public int getChannel() {
			return channel;
		}

		@Override
		public int hashCode() {
			final int prime = 31;
			int result = 1;
			result = prime * result + channel;
			result = prime * result + ((layerData == null) ? 0 : layerData.hashCode());
			return result;
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (obj == null)
				return false;
			if (getClass() != obj.getClass())
				return false;
			LayerChannelPair other = (LayerChannelPair) obj;
			if (channel != other.channel)
				return false;
			if (layerData == null) {
				if (other.layerData != null)
					return false;
			} else if (!layerData.equals(other.layerData))
				return false;
			return true;
		}
	}

	/**
	 * RLEで圧縮するか？
	 */
	private static boolean useRLECompression = true;

	public static boolean isUseRLECompression() {
		return useRLECompression;
	}

	public static void setUseRLECompression(boolean useRLECompression) {
		PSDCreator.useRLECompression = useRLECompression;
	}

	/**
	 * レンダリングヒントを使うか？
	 */
	private static boolean useRenderingHints = true;

	public static boolean isUseRenderingHints() {
		return useRenderingHints;
	}

	public static void setUseRenderingHints(boolean useRenderingHints) {
		PSDCreator.useRenderingHints = useRenderingHints;
	}

	/**
	 * レイヤーを指定してPSDデータを作成する
	 * @param layerDatas レイヤーのコレクション、順番に重ねられる
	 * @return PSDデータ
	 * @throws IOException
	 */
	public static byte[] createPSD(Collection<LayerData> layerDatas) throws IOException {
		if (layerDatas == null) {
			throw new NullPointerException("layerDatas is required.");
		}
		if (layerDatas.isEmpty()) {
			throw new IllegalArgumentException("layerDatas must not be empty.");
		}

		BufferedImage cimg = createCompositeImage(layerDatas);
		int width = cimg.getWidth();
		int height = cimg.getHeight();

		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		DataOutputStream dos = new DataOutputStream(bos);

		dos.write("8BPS".getBytes());
		dos.writeShort(1);
		dos.write(new byte[6]); // reserved 6bytes

		dos.writeShort(4); // argb

		dos.writeInt(height);
		dos.writeInt(width);

		int depth = 8;
		dos.writeShort(depth);

		dos.writeShort(3); // ColorMode=RGB(3)

		dos.writeInt(0); // カラーモードセクションなし
		dos.writeInt(0); // リソースセクションなし

		// レイヤーマスクセクション
		byte[] layerMaskSection = createLayerMaskSection(layerDatas);
		dos.writeInt(layerMaskSection.length);
		dos.write(layerMaskSection);

		// 画像セクション
		byte[] pictureDatas = createPictureSection(cimg, width, height);
		dos.write(pictureDatas);

		return bos.toByteArray();
	}

	/**
	 * レイヤーマスクセクションを作成する
	 * @param layerDatas
	 * @return
	 * @throws IOException
	 */
	private static byte[] createLayerMaskSection(Collection<LayerData> layerDatas) throws IOException {
		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		DataOutputStream dos = new DataOutputStream(bos);

		byte[] layerData = createLayerData(layerDatas);
		dos.writeInt(layerData.length);
		dos.write(layerData);

		return bos.toByteArray();
	}

	/**
	 * レイヤーデータの作成
	 * @param layerDatas
	 * @return
	 * @throws IOException
	 */
	private static byte[] createLayerData(Collection<LayerData> layerDatas) throws IOException {
		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		DataOutputStream dos = new DataOutputStream(bos);

		int numOfLayers = layerDatas.size();
		dos.writeShort(numOfLayers); // non pre-multiplied

		short[] channels = { -1, 0, 1, 2 }; // ALPHA, RED, GREEN, BLUE

		Map<LayerChannelPair, byte[]> channelDataMap = new HashMap<LayerChannelPair, byte[]>();
		for (LayerData layerData : layerDatas) {
			String layerName = layerData.getLayerName();
			BufferedImage image = layerData.getImage();
			int width = image.getWidth();
			int height = image.getHeight();

			dos.writeInt(0); // top
			dos.writeInt(0); // left
			dos.writeInt(height); // bottom
			dos.writeInt(width); // right

			dos.writeShort(channels.length);

			byte[][] channelsData = createChannels(image);

			for (int channel = 0; channel < channels.length; channel++) {
				byte[] channelData = channelsData[channel];
				byte[] outChannelData;

				if (useRLECompression) {
					// RLE圧縮
					// 行ごとにRLE圧縮する
					int bufsiz = 0;
					List<byte[]> rleRows = new ArrayList<byte[]>();
					for (int y = 0; y < height; y++) {
						byte[] rleRow = compressRLE(channelData, y * width, width);
						rleRows.add(rleRow);
						bufsiz += 2 + rleRow.length;
					}

					ByteBuffer outbuf = ByteBuffer.allocate(bufsiz);

					// 行ごとの圧縮サイズを格納
					for (byte[] rleRow : rleRows) {
						outbuf.putShort((short) rleRow.length);
					}
					// 行ごとに圧縮後データの格納
					for (byte[] rleRow : rleRows) {
						outbuf.put(rleRow);
					}

					outChannelData = outbuf.array();

				} else {
					// RAW (圧縮なし)
					outChannelData = channelData;
				}

				// チャネルID (-1: alpha, 0: red, 1:green, 2:blue)
				dos.writeShort(channels[channel]);

				// チャネルのデータサイズ
				dos.writeInt(2 + outChannelData.length);

				channelDataMap.put(new LayerChannelPair(layerData, channel), outChannelData);
			}

			dos.write("8BIM".getBytes());
			dos.write("norm".getBytes());

			dos.write((byte) 255); // opacity
			dos.write((byte) 0); // clipping
			dos.write((byte) 0); // protection
			dos.write((byte) 0); // filler

			byte[] layerMaskData = createLayerMaskData();
			byte[] layerBlendingData = createLayerBlendingData();
			byte[] layerNameData = createLayerName(layerName);
			int lenOfAdditional = layerMaskData.length + layerBlendingData.length + layerNameData.length;

			dos.writeInt(lenOfAdditional);
			dos.write(layerMaskData);
			dos.write(layerBlendingData);
			dos.write(layerNameData);
		}

		for (LayerData layerData : layerDatas) {
			for (int channel = 0; channel < channels.length; channel++) {
				byte[] outChannelData = channelDataMap.get(new LayerChannelPair(layerData, channel));
				assert outChannelData != null;

				dos.writeShort(useRLECompression ? 1 : 0); // 0:RAW 1:RLE 2..zip
				dos.write(outChannelData);
			}
		}

		return bos.toByteArray();
	}

	/**
	 * 空のレイヤーマスクデータ作成
	 * @return
	 * @throws IOException
	 */
	private static byte[] createLayerMaskData() throws IOException {
		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		DataOutputStream dos = new DataOutputStream(bos);
		dos.writeInt(0);
		return bos.toByteArray();
	}

	/**
	 * 空のレイヤーブレンダーデータの作成
	 * @return
	 * @throws IOException
	 */
	private static byte[] createLayerBlendingData() throws IOException {
		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		DataOutputStream dos = new DataOutputStream(bos);
		dos.writeInt(0);
		return bos.toByteArray();
	}

	/**
	 * レイヤー名の作成
	 * @param layerName
	 * @return
	 * @throws IOException
	 */
	private static byte[] createLayerName(String layerName) throws IOException {
		byte[] nameBuf = layerName.getBytes("UTF-8");
		int layerNameSize = 1 + nameBuf.length; // PASCAL文字列長 (16の倍数サイズ)
		int blockSize = (layerNameSize / 4) * 4 + ((layerNameSize % 4 > 0) ? 4 : 0);
		int paddingSize = blockSize - layerNameSize;

		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		DataOutputStream dos = new DataOutputStream(bos);
		dos.write((byte) nameBuf.length);
		dos.write(nameBuf);
		dos.write(new byte[paddingSize]);
		return bos.toByteArray();
	}

	/**
	 * 32ビットARGB形式のBuffeedImageを受け取り、
	 * ARGBのbyte[][]配列に変換して返す。
	 * @param img イメージ
	 * @return チャネル別配列
	 */
	private static byte[][] createChannels(BufferedImage img) {
		WritableRaster raster = img.getRaster();
		DataBufferInt buffer = (DataBufferInt) raster.getDataBuffer();
		int[] pixels = buffer.getData();

		int width = img.getWidth();
		int height = img.getHeight();
		int mx = width * height;
		byte[][] channels = new byte[4][mx];
		for (int idx = 0; idx < mx; idx++) {
			int argb = pixels[idx];

			int alpha = (argb >> 24) & 0xff;
			int red = (argb >> 16) & 0xff;
			int green = (argb >> 8) & 0xff;
			int blue = argb & 0xff;

			channels[0][idx] = (byte) alpha;
			channels[1][idx] = (byte) red;
			channels[2][idx] = (byte) green;
			channels[3][idx] = (byte) blue;
		}

		return channels;
	}

	/**
	 * レイヤーコレクションを重ねて1つの画像にして返す
	 * @param layerDatas レイヤーコレクション
	 * @return 重ね合わせた画像
	 */
	private static BufferedImage createCompositeImage(Collection<LayerData> layerDatas) {
		int width = 0;
		int height = 0;
		for (LayerData layerData : layerDatas) {
			BufferedImage img = layerData.getImage();
			int w = img.getWidth();
			int h = img.getHeight();
			if (w > width) {
				width = w;
			}
			if (h > height) {
				height = h;
			}
		}

		BufferedImage cimg = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);

		Graphics2D g = cimg.createGraphics();
		try {
			if (isUseRenderingHints()) {
				// リンダリングヒントを使う
				g.setRenderingHint(
						RenderingHints.KEY_ALPHA_INTERPOLATION,
						RenderingHints.VALUE_ALPHA_INTERPOLATION_QUALITY);
				g.setRenderingHint(
						RenderingHints.KEY_COLOR_RENDERING,
						RenderingHints.VALUE_COLOR_RENDER_QUALITY);
				g.setRenderingHint(
						RenderingHints.KEY_RENDERING,
						RenderingHints.VALUE_RENDER_QUALITY);
			}
			for (LayerData layerData : layerDatas) {
				BufferedImage img = layerData.getImage();
				int w = img.getWidth();
				int h = img.getHeight();
				g.drawImage(img, 0, 0, w, h, 0, 0, w, h, null);
			}
		} finally {
			g.dispose();
		}
		return cimg;
	}

	/**
	 * ARGB画像をRLE圧縮されたピクチャーセクションデータに変換する
	 * @param img 画像
	 * @param width 幅、画像とPSDヘッダと一致していること
	 * @param height 高さ、画像とPSDヘッダと一致していること
	 * @return RLE圧縮されたRGBA順チャンネルをつなげたデータ
	 */
	private static byte[] createPictureSection(BufferedImage img, int width, int height) {
		byte[][] channels = createChannels(img);

		assert width == img.getWidth();
		assert height == img.getHeight();

		int[] channelMap = { 1, 2, 3, 0 }; // R, G, B, Aにマップ

		ByteBuffer channelData;
		if (useRLECompression) {
			// RLE圧縮とサイズの計算
			int bufsiz = 2;
			List<byte[]> rows = new ArrayList<byte[]>();
			for (int channel = 0; channel < channels.length; channel++) {
				byte[] pixels = channels[channelMap[channel]];
				for (int y = 0; y < height; y++) {
					byte[] row = compressRLE(pixels, y * width, width);
					rows.add(row);
					bufsiz += 2 + row.length; // ラインごとのバイト数保存(16bit)とラインデータ分を加算
				}
			}

			// RLE圧縮済みバッファ作成
			channelData = ByteBuffer.allocate(bufsiz);
			channelData.order(ByteOrder.BIG_ENDIAN);

			channelData.putShort((short) 1); // RLE圧縮

			// 各チャネルの各行ごとのデータ
			for (byte[] row : rows) {
				channelData.putShort((short) row.length);
			}
			for (byte[] row : rows) {
				channelData.put(row);
			}

		} else {
			// RAWサイズの計算
			int bufsiz = 2;
			for (int channel = 0; channel < channels.length; channel++) {
				byte[] pixels = channels[channelMap[channel]];
				bufsiz += pixels.length;
			}

			// RLE圧縮済みバッファ作成
			channelData = ByteBuffer.allocate(bufsiz);
			channelData.order(ByteOrder.BIG_ENDIAN);

			channelData.putShort((short) 0); // RAW

			for (int channel = 0; channel < channels.length; channel++) {
				byte[] pixels = channels[channelMap[channel]];
				channelData.put(pixels);
			}
		}

		return channelData.array();
	}

	/**
	 * バイト配列をRLE圧縮して返す
	 *  http://www.snap-tck.com/room03/c02/comp/comp02.html
	 * @param data 圧縮するバイト配列
	 * @param offset 開始位置
	 * @param length 長さ
	 * @return RLE圧縮結果
	 */
	public static byte[] compressRLE(byte[] data, int offset, int length) {
		ByteBuffer outbuf = ByteBuffer.allocate(length * 2); // ワーストケース
		ByteBuffer buf = ByteBuffer.wrap(data, offset, length);
		while (buf.hasRemaining()) {
			int ch = buf.get();
			// 不連続数を数える
			int count = 0;
			buf.mark();
			int prev = ch;
			while (buf.hasRemaining() && count < 128) {
				int ch2 = buf.get();
				if (prev == ch2) {
					break;
				}
				count++;
				prev = ch2;
				if (!buf.hasRemaining() && count < 128) {
					// 終端に達した場合は終端も不連続数と数える
					count++;
					break;
				}
			}
			buf.reset();

			if (count > 0) {
				// 不連続数がある場合
				outbuf.put((byte) (count - 1));
				outbuf.put((byte) ch);
				while (--count > 0) {
					ch = buf.get();
					outbuf.put((byte) ch);
				}

			} else {
				// 連続数を数える
				prev = ch;
				count = 1;
				while (buf.hasRemaining() && count < 128) {
					ch = buf.get();
					if (prev != ch) {
						buf.reset();
						break;
					}
					count++;
					buf.mark();
				}
				outbuf.put((byte) (-count + 1));
				outbuf.put((byte) prev);
			}
		}

		outbuf.flip();
		int limit = outbuf.limit();
		byte[] array = outbuf.array();
		byte[] result = new byte[limit];
		System.arraycopy(array, 0, result, 0, limit);
		return result;
	}
}
