JavaでJVM言語を作ってみる(6) - 処理系の実装
これまでの実験のまとめとして、最後に処理系を作ってみました。
言語仕様の概略は以下の通り。
プログラミング言語 (無名) ※便宜上Langと呼びます
主クラス。
基底クラスLangImplementationを作り、CompilerとInterpreterにそれぞれ派生させています。
- abstract class LangImplementation
- final class LangCompiler extends LangImplementation
- final class LangInterpreter extends LangImplementation
- interface LangImplementationVisitor
- final class LangCompilerVisitor implements LangImplementationVisitor
- final class LangInterpreterVisitor implements LangImplementationVisitor
// 1 + 2
iconst_1
iconst_2
iadd
コンパイラではこのようなバイトコードになるところを、インタプリタの場合はそのままStackを使って
// ※厳密なJavaコードではありません
// 1 + 2
stack.push(1);
stack.push(2);
b = stack.pop();
a = stack.pop();
stack.push(a + b);
のようにしています。
- サンプルコード test.lang
# add
a = 1
b = 2
c = a + b
println c

# circumference
pi = 3.1416
r = 3.0
println 2 * pi * r

# print format
fmt = "%04d-%02d-%02d%n"
printf fmt 2012 2 25
$ langc test.lang
$ java Lang
3
18.8496
2012-02-25
$ langi test.lang
3
18.8496
2012-02-25
$
以下、主なソースコードです。
- lang.jjt
- LangCompiler
- LangCompilerVisitor
- LangInterpreter
- LangInterpreterVisitor
lang.jjt
/*
 * JJTree grammar for the Lang toy language.
 *
 * A program is a sequence of lines. Each non-empty line is one statement:
 *   print / println <expr>
 *   printf <format> <expr>...
 *   <identifier> = <expr>
 * Expressions support + - * / % with the usual precedence and parentheses.
 *
 * With MULTI=true and NODE_PREFIX="Lang", every node named #Xxx below is
 * generated as class LangXxx; NODE_DEFAULT_VOID=true means a production only
 * creates a node where a #Name decoration is written explicitly.
 */
options {
    STATIC=false;
    MULTI=true;
    VISITOR=true;
    NODE_DEFAULT_VOID=true;
    NODE_PREFIX="Lang";
}

PARSER_BEGIN(LangImplementation)
package lang;

/** Lang Implementation. */
@SuppressWarnings("all")
public abstract class LangImplementation {
}
PARSER_END(LangImplementation)

SKIP :
{
    " "
  | "\t"
}

/*
 * A comment runs from '#' up to, but NOT including, the line terminator.
 * The terminator must stay in the token stream as <EOL>: a statement followed
 * by a trailing comment ("a = 1 # note") still needs its <EOL> so that
 * EndOfStatement() can match. A comment-only line simply becomes a blank
 * line, which Start() already accepts.
 */
SPECIAL_TOKEN :
{
    <INLINE_COMMENT: "#" (~["\n", "\r"])* >
}

TOKEN :
{
    < EOL: "\r\n" | "\r" | "\n" >
}

TOKEN :
{
    < F_PRINT: ( "print" | "println" ) >
  | < F_PRINTF: "printf" >
}

TOKEN :
{
    < FLOAT_LITERAL: ( (["1"-"9"](["0"-"9"])*) | "0" ) "." (["0"-"9"])* >
  | < INTEGER_LITERAL: ( (["1"-"9"](["0"-"9"])*) | "0" ) >
  | < STRING_LITERAL: "\"" (~["\"","\\","\n","\r"])* "\"" >
}

TOKEN :
{
    < IDENTIFIER: [ "A"-"Z","a"-"z" ]([ "A"-"Z","a"-"z","0"-"9" ])* >
}

TOKEN :
{
    < OP_ASSIGN: "=" >
}

TOKEN :
{
    < OP_ADD: ["+","-"] >
}

TOKEN :
{
    < OP_MUL: ["*","/","%"] >
}

/* Root production: blank lines and statements until the input ends. */
LangStart Start() #Start :
{}
{
    (
        <EOL>
      | ( ( PrintExpression() | PrintfExpression() | AssignExpression() ) EndOfStatement() )
    )*
    { return jjtThis; }
}

/* A statement ends at a line terminator or at the end of the input. */
void EndOfStatement() :
{}
{
    ( <EOL> | <EOF> )
}

/* print / println with exactly one operand; the node value is the function name. */
void PrintExpression() :
{
    Token t;
}
{
    t=<F_PRINT> AdditiveExpression() #FnPrint(1)
    { jjtn001.jjtSetValue(t.image); }
}

/*
 * printf: a format (string literal or identifier) followed by any number of
 * argument expressions.
 *
 * The node is created unconditionally. A conditional "#FnPrintf(>2)" only
 * builds the node when there are three or more children, which silently
 * discards statements such as  printf "%d%n" 1  and keeps the visitors from
 * ever seeing (and reporting) a printf with too few arguments.
 */
void PrintfExpression() :
{
    Token t;
}
{
    t=<F_PRINTF> ( ( String() | Identifier() ) ( AdditiveExpression() )* ) #FnPrintf
    { jjtn001.jjtSetValue(t.image); }
}

/* identifier = expression; children are [Identifier, expression]. */
void AssignExpression() :
{
    Token t;
}
{
    Identifier()
    (
        t=<OP_ASSIGN> AdditiveExpression() #OpAssign(2)
        { jjtn001.jjtSetValue(t.image); }
    )
}

/* Left-associative + and -; the node value is the operator text. */
void AdditiveExpression() :
{
    Token t;
}
{
    MultiplicativeExpression()
    (
        t=<OP_ADD> MultiplicativeExpression() #OpAdd(2)
        { jjtn001.jjtSetValue(t.image); }
    )*
}

/* Left-associative *, / and %; the node value is the operator text. */
void MultiplicativeExpression() :
{
    Token t;
}
{
    UnaryExpression()
    (
        t=<OP_MUL> UnaryExpression() #OpMul(2)
        { jjtn001.jjtSetValue(t.image); }
    )*
}

/* Parenthesized expression, variable reference, or literal. */
void UnaryExpression() :
{}
{
    "(" AdditiveExpression() ")"
  | Identifier()
  | Value()
}

void Identifier() #Identifier :
{
    Token t;
}
{
    t=<IDENTIFIER>
    { jjtThis.jjtSetValue(t.image); }
}

void Value() :
{}
{
    Float()
  | Integer()
  | String()
}

void Float() #Float :
{
    Token t;
}
{
    t=<FLOAT_LITERAL>
    { jjtThis.jjtSetValue(t.image); }
}

void Integer() #Integer :
{
    Token t;
}
{
    t=<INTEGER_LITERAL>
    { jjtThis.jjtSetValue(t.image); }
}

/* The node value keeps the surrounding double quotes; visitors strip them. */
void String() #String :
{
    Token t;
}
{
    t=<STRING_LITERAL>
    { jjtThis.jjtSetValue(t.image); }
}
LangCompiler
package lang;

import java.io.*;

/**
 * Lang Compiler.
 *
 * Command-line entry point that parses a script file with the generated
 * parser and hands the syntax tree to {@link LangCompilerVisitor}.
 */
public final class LangCompiler extends LangImplementation {

    /**
     * Creates a compiler that parses the given character stream.
     *
     * @param stream source text of the script
     */
    public LangCompiler(Reader stream) {
        super(stream);
    }

    /** Prints the command-line usage to standard output. */
    static void printUsage() {
        System.out.println("usage: langc script-file");
    }

    /**
     * Compiles the script named by the first argument.
     *
     * Language-level failures (LangError) and syntax errors (ParseException)
     * are reported as a single "compile error" line on standard error; any
     * other exception is printed with its stack trace.
     *
     * @param args command-line arguments; {@code args[0]} is the script file
     */
    public static void main(String[] args) {
        if (args.length == 0) {
            printUsage();
            return;
        }
        try {
            Reader r = new FileReader(args[0]);
            try {
                LangCompiler t = new LangCompiler(r);
                t.Start().jjtAccept(new LangCompilerVisitor(), null);
            } finally {
                r.close();
            }
        } catch (LangError ex) {
            System.err.println("compile error: " + ex.getMessage());
        } catch (ParseException ex) {
            // A syntax error is a user mistake, not an internal failure:
            // report it like LangError instead of dumping a stack trace.
            System.err.println("compile error: " + ex.getMessage());
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}
LangCompilerVisitor
package lang;

import static org.apache.bcel.Constants.*;

import java.io.*;
import java.util.*;

import org.apache.bcel.*;
import org.apache.bcel.classfile.*;
import org.apache.bcel.generic.*;

/**
 * Syntax-tree visitor that emits JVM bytecode through BCEL.
 *
 * The whole script is compiled into the body of a single
 * {@code public static void main(String[])} method of a class named
 * {@link #CLASS_NAME}, which is written to {@code Lang.class} when the
 * root node has been processed.
 */
public final class LangCompilerVisitor implements LangImplementationVisitor {

    /** Name of the generated class (and base name of the class file). */
    static final String CLASS_NAME = "Lang";

    ClassGen cg;
    InstructionList il;
    InstructionFactory factory;
    /** Variable names in order of first assignment; the list index is the local slot. */
    List<String> varList;
    /** Static type of the value most recently assigned to each variable. */
    LinkedHashMap<String, Type> varTypeMap;

    LangCompilerVisitor() {
        this.il = new InstructionList();
        this.varList = new ArrayList<String>();
        this.varTypeMap = new LinkedHashMap<String, Type>();
    }

    /** Loads the class used as the template for the generated class. */
    private static JavaClass getBaseClass() {
        try {
            return Repository.lookupClass(EmptyClass.class);
        } catch (ClassNotFoundException ex) {
            throw new RuntimeException(ex);
        }
    }

    /** Empty template class handed to {@link ClassGen}. */
    static class EmptyClass {
        //
    }

    /**
     * Converts a literal node into the Java value it denotes.
     *
     * @param node a Float, Integer or String literal node
     * @return a {@link Float}, an {@link Integer}, or the string without its quotes
     * @throws LangError if the node is not a literal
     */
    static Object toRawObject(SimpleNode node) {
        final String s = String.valueOf(node.jjtGetValue());
        if (node instanceof LangFloat) {
            return Float.valueOf(s);
        }
        if (node instanceof LangInteger) {
            return Integer.valueOf(s);
        }
        if (node instanceof LangString) {
            return s.substring(1, s.length() - 1);
        }
        throw new LangError("toRawObject: " + node);
    }

    /**
     * Determines the static type of an expression node.
     *
     * @param node expression node
     * @return the BCEL type of the value the node leaves on the operand stack
     * @throws LangError if the node is not an expression, or refers to a
     *         variable that has not been assigned yet
     */
    Type getType(Node node) {
        if (node instanceof SimpleNode) {
            if ((node instanceof LangOpAdd) || (node instanceof LangOpMul)) {
                // Must agree with calculate(): equal operand types keep that
                // type, mixed operands are promoted to FLOAT. Looking at the
                // last child only would type "1.5 + 2" as INT.
                final Type left = getType(node.jjtGetChild(0));
                final Type right = getType(node.jjtGetChild(1));
                return (left == right) ? left : Type.FLOAT;
            }
            if (node instanceof LangIdentifier) {
                final Object name = ((SimpleNode)node).jjtGetValue();
                final Type type = varTypeMap.get(name);
                if (type == null) {
                    throw new LangError("unassigned variable: " + name);
                }
                return type;
            }
            if (node instanceof LangFloat) {
                return Type.FLOAT;
            }
            if (node instanceof LangInteger) {
                return Type.INT;
            }
            if (node instanceof LangString) {
                return Type.STRING;
            }
        }
        throw new LangError("getType:" + node);
    }

    /**
     * Emits code for a binary arithmetic node: both operands, an int-to-float
     * conversion where the operand types differ, then the operation itself.
     *
     * @param node binary operator node whose value is the operator text
     * @param data visitor data passed through to the children
     * @return the visitor data returned by the children
     * @throws LangError if either operand is a string
     */
    Object calculate(SimpleNode node, Object data) {
        final String op = String.valueOf(node.jjtGetValue());
        final Node c1 = node.jjtGetChild(0);
        final Node c2 = node.jjtGetChild(1);
        final Type t1 = getType(c1);
        final Type t2 = getType(c2);
        if (t1 == Type.STRING || t2 == Type.STRING) {
            // Checked before any instruction is emitted, for every operator.
            throw new LangError(op + " for string not supported");
        }
        final Type t;
        if (t1 == t2) {
            data = node.childrenAccept(this, data);
            t = t1;
        } else {
            data = c1.jjtAccept(this, data);
            if (t1 == Type.INT) {
                il.append(factory.createCast(t1, Type.FLOAT));
            }
            data = c2.jjtAccept(this, data);
            if (t2 == Type.INT) {
                il.append(factory.createCast(t2, Type.FLOAT));
            }
            t = Type.FLOAT;
        }
        il.append(InstructionFactory.createBinaryOperation(op, t));
        return data;
    }

    /** Prepares the class generator before the first statement is compiled. */
    void onBegin() {
        cg = new ClassGen(getBaseClass());
        cg.setClassName(CLASS_NAME);
        factory = new InstructionFactory(cg);
    }

    /** Wraps the collected instructions into {@code main} and writes the class file. */
    void onEnd() {
        il.append(InstructionConstants.RETURN);
        MethodGen mg = new MethodGen(ACC_PUBLIC | ACC_STATIC,
                                     Type.VOID,
                                     new Type[]{new ArrayType(Type.STRING, 1)},
                                     new String[]{"args"},
                                     "main",
                                     CLASS_NAME,
                                     il,
                                     cg.getConstantPool());
        mg.setMaxStack();
        mg.setMaxLocals();
        cg.addMethod(mg.getMethod());
        JavaClass c = cg.getJavaClass();
        c.setSourceFileName(CLASS_NAME + ".java");
        try {
            c.dump(CLASS_NAME + ".class");
        } catch (IOException ex) {
            throw new RuntimeException("", ex);
        }
    }

    @Override
    public Object visit(SimpleNode node, Object data) {
        throw new AssertionError("compile error: unexpected token " + node);
    }

    @Override
    public Object visit(LangStart node, Object data) {
        onBegin();
        node.jjtSetValue("");
        data = node.childrenAccept(this, data);
        onEnd();
        return data;
    }

    /** Emits {@code System.out.print(x)} / {@code println(x)} for the operand's type. */
    @Override
    public Object visit(LangFnPrint node, Object data) {
        final String fName = String.valueOf(node.jjtGetValue());
        il.append(factory.createGetStatic(System.class.getName(), "out", Type.getType(PrintStream.class)));
        data = node.childrenAccept(this, data);
        il.append(factory.createInvoke(java.io.PrintStream.class.getName(),
                                       fName,
                                       Type.VOID,
                                       new Type[]{getType(node.jjtGetChild(0))},
                                       INVOKEVIRTUAL));
        return data;
    }

    /**
     * Emits {@code System.out.printf(format, Object[])}.
     *
     * The first child must be a string literal or a string variable; the
     * remaining children are arbitrary expressions.
     *
     * @throws LangError if the first child is not a string
     */
    @Override
    public Object visit(LangFnPrintf node, Object data) {
        il.append(factory.createGetStatic(System.class.getName(), "out", Type.getType(PrintStream.class)));
        final SimpleNode child0 = (SimpleNode)node.jjtGetChild(0);
        if (child0 instanceof LangString) {
            il.append(factory.createConstant(toRawObject(child0)));
        } else if (child0 instanceof LangIdentifier && getType(child0) == Type.STRING) {
            data = visit((LangIdentifier)child0, data);
        } else {
            throw new LangError("printf 1st argument must be string");
        }
        data = appendPrintfArguments(node, data);
        il.append(factory.createInvoke(java.io.PrintStream.class.getName(),
                                       "printf",
                                       Type.getType(PrintStream.class),
                                       new Type[]{Type.STRING, new ArrayType(Type.OBJECT, 1)},
                                       INVOKEVIRTUAL));
        // printf returns the PrintStream; discard it.
        il.append(InstructionFactory.createPop(1));
        return data;
    }

    /**
     * Emits code that builds the {@code Object[]} holding the printf
     * arguments (children 1..n of the node).
     *
     * Each argument is compiled as an ordinary expression and boxed according
     * to its static type, so variables and arithmetic are accepted in
     * addition to literals.
     */
    private Object appendPrintfArguments(Node node, Object data) {
        final int count = node.jjtGetNumChildren() - 1;
        il.append(factory.createConstant(count));
        il.append(factory.createNewArray(Type.OBJECT, (short)1));
        for (int i = 0; i < count; i++) {
            final Node arg = node.jjtGetChild(i + 1);
            il.append(InstructionFactory.createDup(1));
            il.append(factory.createConstant(i));
            data = arg.jjtAccept(this, data);
            final Type type = getType(arg);
            if (type == Type.FLOAT) {
                il.append(factory.createInvoke(Float.class.getName(),
                                               "valueOf",
                                               Type.getType(Float.class),
                                               new Type[]{Type.FLOAT},
                                               INVOKESTATIC));
            } else if (type == Type.INT) {
                il.append(factory.createInvoke(Integer.class.getName(),
                                               "valueOf",
                                               Type.getType(Integer.class),
                                               new Type[]{Type.INT},
                                               INVOKESTATIC));
            }
            il.append(InstructionFactory.createArrayStore(Type.OBJECT));
        }
        return data;
    }

    /** Emits the right-hand side followed by a store into the variable's local slot. */
    @Override
    public Object visit(LangOpAssign node, Object data) {
        final LangIdentifier identifier = (LangIdentifier)node.jjtGetChild(0);
        final String varName = String.valueOf(identifier.jjtGetValue());
        final Node rhs = node.jjtGetChild(1);
        data = rhs.jjtAccept(this, data);
        // Computed before the map is updated so that "a = a + 0.5" still sees
        // the previous type of "a" while typing the right-hand side.
        final Type type = getType(rhs);
        if (!varList.contains(varName)) {
            varList.add(varName);
        }
        // Always record the latest type: after "a = 1" then "a = 2.0" later
        // loads must use the float form, not the type of the first assignment.
        varTypeMap.put(varName, type);
        // NOTE(review): slots start at 0, which is also where main's "args"
        // parameter lives; the generated code never reads args.
        final int index = varList.indexOf(varName);
        il.append(InstructionFactory.createStore(type, index));
        return data;
    }

    @Override
    public Object visit(LangOpAdd node, Object data) {
        return calculate(node, data);
    }

    @Override
    public Object visit(LangOpMul node, Object data) {
        return calculate(node, data);
    }

    /** Emits a load of the variable; fails with LangError if it was never assigned. */
    @Override
    public Object visit(LangIdentifier node, Object data) {
        final String varName = String.valueOf(node.jjtGetValue());
        final Type type = getType(node);
        final int index = varList.indexOf(varName);
        il.append(InstructionFactory.createLoad(type, index));
        return data;
    }

    @Override
    public Object visit(LangFloat node, Object data) {
        final Float f = Float.valueOf(String.valueOf(node.jjtGetValue()));
        il.append(factory.createConstant(f));
        return data;
    }

    @Override
    public Object visit(LangInteger node, Object data) {
        final Integer i = Integer.valueOf(String.valueOf(node.jjtGetValue()));
        il.append(factory.createConstant(i));
        return data;
    }

    @Override
    public Object visit(LangString node, Object data) {
        il.append(factory.createConstant(toRawObject(node)));
        return data;
    }
}
LangInterpreter
package lang;

import java.io.*;

/**
 * Lang Interpreter.
 *
 * Command-line front end: parses a script file and evaluates the resulting
 * syntax tree with {@link LangInterpreterVisitor}.
 */
public final class LangInterpreter extends LangImplementation {

    /**
     * Creates an interpreter reading the script from the given stream.
     *
     * @param stream source text of the script
     */
    public LangInterpreter(Reader stream) {
        super(stream);
    }

    /** Writes the usage line to standard output. */
    static void printUsage() {
        System.out.println("usage: langi [-d] script-file");
    }

    /**
     * Runs the script given on the command line; a leading {@code -d}
     * switches the visitor into debug mode.
     *
     * @param args {@code [-d] script-file}
     */
    public static void main(String[] args) {
        if (args.length == 0) {
            printUsage();
            return;
        }
        final boolean debug = args[0].equals("-d");
        if (debug && args.length < 2) {
            // "-d" alone: the script file is missing.
            printUsage();
            return;
        }
        final String file = args[debug ? 1 : 0];
        try {
            run(file, debug);
        } catch (LangError ex) {
            System.err.println("error: " + ex.getMessage());
        } catch (ParseException ex) {
            System.err.println("error: " + ex.getMessage());
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    /** Parses and evaluates one script file, always closing the reader. */
    private static void run(String path, boolean debug) throws IOException, ParseException {
        final Reader source = new FileReader(path);
        try {
            final LangInterpreter interpreter = new LangInterpreter(source);
            interpreter.Start().jjtAccept(new LangInterpreterVisitor(debug), null);
        } finally {
            source.close();
        }
    }
}
LangInterpreterVisitor
package lang;

import java.util.*;

/**
 * Syntax-tree visitor that executes a script directly.
 *
 * Expression nodes push their value onto an operand stack; operator and
 * statement nodes pop their operands from it. Variables live in a map from
 * name to value. In debug mode every node visit prints the node together
 * with the current stack and variables.
 */
public class LangInterpreterVisitor implements LangImplementationVisitor {

    private Stack<Object> stack;
    private Map<String, Object> varMap;
    private boolean debug;
    /** Nesting depth used to indent the debug trace. */
    private int indent;

    LangInterpreterVisitor() {
        this(false);
    }

    LangInterpreterVisitor(boolean debug) {
        this.stack = new Stack<Object>();
        this.varMap = new HashMap<String, Object>();
        this.debug = debug;
        this.indent = 0;
    }

    void onBegin() {
        if (debug) {
            System.out.println("*** Lang Interpreter START ***");
        }
    }

    void onEnd() {
        if (debug) {
            System.out.printf("final values = %s%n", stack);
            System.out.println("*** Lang Interpreter E N D ***");
        }
    }

    /**
     * Resolves an operand to a plain value.
     *
     * @param o a Float, Integer or String value, or an identifier node
     * @return the value itself, or the variable's value for an identifier node
     * @throws LangError if the identifier is unassigned or the operand is of
     *         any other kind (including {@code null})
     */
    Object evaluate(Object o) {
        if (o instanceof Float || o instanceof Integer || o instanceof String) {
            return o;
        }
        if (o instanceof LangIdentifier) {
            return lookupVariable(String.valueOf(((LangIdentifier)o).jjtGetValue()));
        }
        // Null-safe: report a language error rather than a NullPointerException.
        throw new LangError("can't evaluate: " + (o == null ? "null" : String.valueOf(o.getClass())));
    }

    /** Returns the value of a variable, failing if it was never assigned. */
    private Object lookupVariable(String id) {
        if (!varMap.containsKey(id)) {
            throw new LangError("unassigned variable: " + id);
        }
        return varMap.get(id);
    }

    /**
     * Applies an integer arithmetic operator.
     *
     * @param op one of {@code + - * / %}
     * @throws LangError on division or remainder by zero
     * @throws AssertionError if {@code op} is not a known operator
     */
    static Integer calculate(String op, int a, int b) {
        switch (op.length() == 1 ? op.charAt(0) : ' ') {
        case '+':
            return a + b;
        case '-':
            return a - b;
        case '*':
            return a * b;
        case '/':
            requireNonZeroDivisor(op, a, b);
            return a / b;
        case '%':
            requireNonZeroDivisor(op, a, b);
            return a % b;
        }
        throw new AssertionError();
    }

    /**
     * Applies a float arithmetic operator.
     *
     * @param op one of {@code + - * / %}
     * @throws AssertionError if {@code op} is not a known operator
     */
    static Float calculate(String op, float a, float b) {
        switch (op.length() == 1 ? op.charAt(0) : ' ') {
        case '+':
            return a + b;
        case '-':
            return a - b;
        case '*':
            return a * b;
        case '/':
            return a / b;
        case '%':
            return a % b;
        }
        throw new AssertionError();
    }

    /** Turns integer division by zero into a language error the front end reports. */
    private static void requireNonZeroDivisor(String op, int a, int b) {
        if (b == 0) {
            throw new LangError(String.format("division by zero: %s %s %s", a, op, b));
        }
    }

    /**
     * Numeric arithmetic shared by the additive and multiplicative visitors:
     * integer arithmetic when both operands are Integer, float otherwise.
     */
    private static Number applyNumeric(String op, Number n1, Number n2) {
        if (n1 instanceof Integer && n2 instanceof Integer) {
            return calculate(op, n1.intValue(), n2.intValue());
        }
        return calculate(op, n1.floatValue(), n2.floatValue());
    }

    static String p(SimpleNode node) {
        return String.format("%s %s", node, node.jjtGetValue());
    }

    void printCodeS(SimpleNode node) {
        printCode(node, ">>> ");
    }

    void printCodeE(SimpleNode node) {
        printCode(node, "<<< ");
    }

    /** Debug trace of one node visit; does nothing unless debug mode is on. */
    void printCode(SimpleNode node, String prefix) {
        if (!debug) {
            return;
        }
        System.out.println(" ### stack = " + stack);
        System.out.println(" ### vars = " + varMap);
        char[] a = new char[indent * 2];
        Arrays.fill(a, ' ');
        System.out.printf("%s%s%s%n", prefix, String.valueOf(a), p(node));
    }

    @Override
    public Object visit(SimpleNode node, Object data) {
        throw new AssertionError("unexpected token " + node);
    }

    @Override
    public Object visit(LangStart node, Object data) {
        onBegin();
        node.jjtSetValue("");
        printCodeS(node);
        ++indent;
        data = node.childrenAccept(this, data);
        --indent;
        onEnd();
        return data;
    }

    /** print / println: evaluates the single operand and writes it to standard output. */
    @Override
    public Object visit(LangFnPrint node, Object data) {
        printCodeS(node);
        ++indent;
        data = node.childrenAccept(this, data);
        final String fname = String.valueOf(node.jjtGetValue());
        final Object o = evaluate(stack.pop());
        if (fname.equals("print")) {
            System.out.print(o);
        } else if (fname.equals("println")) {
            System.out.println(o);
        }
        --indent;
        printCodeE(node);
        return data;
    }

    /**
     * printf: the first child is the format, the remaining children are the
     * arguments.
     *
     * @throws LangError if there is no argument besides the format, or the
     *         format does not fit the arguments
     */
    @Override
    public Object visit(LangFnPrintf node, Object data) {
        printCodeS(node);
        ++indent;
        if (node.jjtGetNumChildren() > 1) {
            data = node.childrenAccept(this, data);
            final int n = node.jjtGetNumChildren();
            final Object[] args = new Object[n - 1];
            // Arguments were pushed left to right, so they come off in reverse.
            for (int i = args.length - 1; i >= 0; i--) {
                args[i] = stack.pop();
            }
            final String format = String.valueOf(stack.pop());
            try {
                System.out.printf(format, args);
            } catch (IllegalFormatException ex) {
                // A bad format string is a script error, not an interpreter crash.
                throw new LangError("printf: " + ex.getMessage());
            }
        } else {
            throw new LangError("printf requires at least 2 arguments");
        }
        --indent;
        printCodeE(node);
        return data;
    }

    /** Assignment: evaluates the right-hand side only and stores it under the variable name. */
    @Override
    public Object visit(LangOpAssign node, Object data) {
        printCodeS(node);
        LangIdentifier identifier = (LangIdentifier)node.jjtGetChild(0);
        final String varName = String.valueOf(identifier.jjtGetValue());
        ++indent;
        // evaluate only RHS
        data = node.jjtGetChild(1).jjtAccept(this, data);
        final Object a = evaluate(stack.pop());
        varMap.put(varName, a);
        --indent;
        printCodeE(node);
        return data;
    }

    /**
     * Additive operator: numeric + and -, plus string concatenation when the
     * left operand is a string and the operator is +.
     */
    @Override
    public Object visit(LangOpAdd node, Object data) {
        printCodeS(node);
        final String op = String.valueOf(node.jjtGetValue());
        ++indent;
        data = node.childrenAccept(this, data);
        final Object o2 = stack.pop();
        final Object o1 = stack.pop();
        final Object r;
        if (o1 instanceof Number && o2 instanceof Number) {
            r = applyNumeric(op, (Number)o1, (Number)o2);
        } else if (o1 instanceof String && op.equals("+")) {
            r = ((String)o1) + o2;
        } else {
            throw new LangError(String.format("%s %s %s", o1, op, o2));
        }
        stack.push(r);
        --indent;
        printCodeE(node);
        return data;
    }

    /** Multiplicative operator: numeric *, / and % only. */
    @Override
    public Object visit(LangOpMul node, Object data) {
        printCodeS(node);
        final String op = String.valueOf(node.jjtGetValue());
        ++indent;
        data = node.childrenAccept(this, data);
        final Object o2 = stack.pop();
        final Object o1 = stack.pop();
        final Object r;
        if (o1 instanceof Number && o2 instanceof Number) {
            r = applyNumeric(op, (Number)o1, (Number)o2);
        } else {
            throw new LangError(String.format("%s %s %s", o1, op, o2));
        }
        stack.push(r);
        --indent;
        printCodeE(node);
        return data;
    }

    /**
     * Variable reference: pushes the variable's current value.
     *
     * @throws LangError if the variable has not been assigned; pushing
     *         {@code null} instead would only fail later with a less helpful error
     */
    @Override
    public Object visit(LangIdentifier node, Object data) {
        printCodeS(node);
        stack.push(lookupVariable(String.valueOf(node.jjtGetValue())));
        printCodeE(node);
        return data;
    }

    @Override
    public Object visit(LangFloat node, Object data) {
        printCodeS(node);
        stack.push(Float.valueOf(String.valueOf(node.jjtGetValue())));
        printCodeE(node);
        return data;
    }

    @Override
    public Object visit(LangInteger node, Object data) {
        printCodeS(node);
        stack.push(Integer.valueOf(String.valueOf(node.jjtGetValue())));
        printCodeE(node);
        return data;
    }

    /** String literal: pushes the text without its surrounding double quotes. */
    @Override
    public Object visit(LangString node, Object data) {
        printCodeS(node);
        final String s = (String)node.jjtGetValue();
        stack.push(s.substring(1, s.length() - 1));
        printCodeE(node);
        return data;
    }
}
こんな感じで、それっぽいものはできました。
本格的なものにするにはまだまだ先は長いですが、地道にやってみます。
今回はこれで終わりです。