argius note

プログラミング関連

JavaでJVM言語を作ってみる(6) - 処理系の実装

これまでの実験のまとめとして、最後に処理系を作ってみました。


言語仕様の概略は以下の通り。

プログラミング言語 (無名) ※便宜上Langと呼びます

  • mainメソッドのみ
  • 改行までが1ステートメント
  • "#"はインラインコメント
  • 予約語
    • "print","println","printf"(System.outに対応)
  • 演算子: "=", "+", "-", "*", "/", "%"
  • 型:int, float, String
  • マルチバイト非対応
  • 実行コマンド: スクリプトファイルを1つ引数に指定する

主クラス。
基底クラスLangImplementationを作り、CompilerとInterpreterにそれぞれ派生させています。

  • abstract class LangImplementation
    • final class LangCompiler extends LangImplementation
    • final class LangInterpreter extends LangImplementation
  • interface LangImplementationVisitor
    • final class LangCompilerVisitor implements LangImplementationVisitor
    • final class LangInterpreterVisitor implements LangImplementationVisitor


インタプリタについて。
コンパイラでは、

// 1 + 2
iconst_1
iconst_2
iadd

となるところを、インタプリタの場合はそのままStackを使って

// ※厳密なJavaコードではありません
// 1 + 2
stack.push(1);
stack.push(2);
a = stack.pop();
b = stack.pop();
stack.push(a + b);

のようにしています。

  • サンプルコード test.lang
# add
a = 1
b = 2
c = a + b
println c

# circumference
pi = 3.1416
r = 3.0
println 2 * pi * r

# print format
fmt = "%04d-%02d-%02d%n"
printf fmt 2012 2 25
$ langc test.lang 
$ java Lang
3
18.8496
2012-02-25
$ langi test.lang
3
18.8496
2012-02-25
$

以下、主なソースコードです。

  • lang.jjt
  • LangCompiler
  • LangCompilerVisitor
  • LangInterpreter
  • LangInterpreterVisitor

lang.jjt

// JJTree/JavaCC generation options for the Lang grammar.
options {
  // Generate an instantiable parser (the front ends create it with "new" and call Start() on the instance).
  STATIC=false;
  // Generate one node class per node name (LangStart, LangFnPrint, ...) instead of a single SimpleNode.
  MULTI=true;
  // Generate the visitor interface and jjtAccept() methods used by the compiler and interpreter visitors.
  VISITOR=true;
  // Productions build a tree node only where a "#Name" descriptor is written explicitly.
  NODE_DEFAULT_VOID=true;
  // Prefix for the generated node class names.
  NODE_PREFIX="Lang";
}

// Parser class declaration. The generated parser is abstract; the concrete
// front ends LangCompiler and LangInterpreter extend it.
PARSER_BEGIN(LangImplementation)
package lang;

/** Lang Implementation. */
@SuppressWarnings("all")
public abstract class LangImplementation { }

PARSER_END(LangImplementation)


// Blanks and tabs between tokens are ignored. Line terminators are NOT skipped:
// they are significant and are reported as <EOL> below.
SKIP : { " " | "\t" }

// "#" starts a comment that runs to the end of the line. The optional line
// terminator is part of the comment token itself.
// NOTE(review): because the terminator is swallowed here, a comment that trails
// a statement on the same line appears to leave that statement without an <EOL>
// unless it is the last line of the file - confirm whether that is intended.
SPECIAL_TOKEN : { <INLINE_COMMENT: "#" (~["\n", "\r"])* ("\n" | "\r" | "\r\n")?> }

// Statement terminator.
TOKEN : { < EOL: "\r\n" | "\r" | "\n" > }

// Output keywords. "print" and "println" share one token; the matched text is
// later stored as the node value. "printf" has its own token.
TOKEN : {
 < F_PRINT: ( "print" | "println" ) >
|< F_PRINTF: "printf" >
}

// Literals.
//  - FLOAT_LITERAL: integer part without leading zeros, a ".", then zero or more digits.
//  - INTEGER_LITERAL: "0" or a digit sequence without a leading zero.
//  - STRING_LITERAL: double-quoted; quotes, backslashes and line breaks are not
//    allowed inside, so there are no escape sequences.
// No sign is part of a numeric literal.
TOKEN : {
 < FLOAT_LITERAL: ( (["1"-"9"](["0"-"9"])*) | "0" ) "." (["0"-"9"])* >
|< INTEGER_LITERAL: ( (["1"-"9"](["0"-"9"])*) | "0" ) >
|< STRING_LITERAL: "\"" (~["\"","\\","\n","\r"])* "\"" >
}

// Variable names: an ASCII letter followed by ASCII letters or digits.
TOKEN : { < IDENTIFIER: [ "A"-"Z","a"-"z" ]([ "A"-"Z","a"-"z","0"-"9" ])* > }

// Assignment operator.
TOKEN : { < OP_ASSIGN: "=" > }

// Additive operators.
TOKEN : { < OP_ADD: ["+","-"] > }

// Multiplicative operators.
TOKEN : { < OP_MUL: ["*","/","%"] > }


// Entry point: a sequence of empty lines and statements. A statement is a
// print, printf or assignment followed by a line terminator or end of file.
// Returns the root node of the syntax tree.
LangStart Start() #Start : {}
{ ( <EOL> | ( ( PrintExpression() | PrintfExpression() | AssignExpression() ) EndOfStatement() ) )* { return jjtThis; } }

// A statement ends at a line terminator or at end of input.
void EndOfStatement() : {}
{ ( <EOL> | <EOF> ) }

// print / println with exactly one expression. Builds an FnPrint node with that
// expression as its only child; the node value is the keyword text.
void PrintExpression() : { Token t; }
{ t=<F_PRINT> AdditiveExpression() #FnPrint(1) { jjtn001.jjtSetValue(t.image); } }

// printf: a format (string literal or identifier) followed by zero or more
// argument expressions, collected as children of an FnPrintf node.
// NOTE(review): "(>2)" is a conditional node descriptor, so the FnPrintf node
// appears to be built only when more than two children are available; a printf
// with a format and fewer than two arguments would leave its operands as
// children of the enclosing node - confirm whether ">0" or ">1" was intended.
void PrintfExpression() : { Token t; }
{ t=<F_PRINTF> ( ( String() | Identifier() ) ( AdditiveExpression() )* ) #FnPrintf(>2) { jjtn001.jjtSetValue(t.image); } }

// identifier "=" expression. Builds an OpAssign node whose children are the
// identifier (child 0) and the right-hand side (child 1).
void AssignExpression() : { Token t; }
{ Identifier() ( t=<OP_ASSIGN> AdditiveExpression() #OpAssign(2) { jjtn001.jjtSetValue(t.image); } ) }

// Left-associative chain of "+" / "-". Each operator wraps the two most recent
// operands in an OpAdd node whose value is the operator text.
void AdditiveExpression() : { Token t; }
{ MultiplicativeExpression() ( t=<OP_ADD> MultiplicativeExpression() #OpAdd(2) { jjtn001.jjtSetValue(t.image); } )* } 

// Left-associative chain of "*" / "/" / "%"; binds tighter than the additive
// operators. Each operator builds an OpMul node whose value is the operator text.
void MultiplicativeExpression() : { Token t; }
{ UnaryExpression() ( t=<OP_MUL> UnaryExpression() #OpMul(2) { jjtn001.jjtSetValue(t.image); } )* }

// Primary expression: parenthesised expression, variable reference or literal.
// Despite the name, no prefix operator (such as unary minus) is accepted here.
void UnaryExpression() : {}
{ "(" AdditiveExpression() ")" | Identifier() | Value() }

// Variable name; the node value is the identifier text.
void Identifier() #Identifier : { Token t; }
{ t=<IDENTIFIER> { jjtThis.jjtSetValue(t.image); } }

// Any literal.
void Value() : {}
{ Float() | Integer() | String() }

// Float literal; the node value is the literal text.
void Float() #Float : { Token t; }
{ t=<FLOAT_LITERAL> { jjtThis.jjtSetValue(t.image); } }

// Integer literal; the node value is the literal text.
void Integer() #Integer : { Token t; }
{ t=<INTEGER_LITERAL> { jjtThis.jjtSetValue(t.image); } }

// String literal; the node value is the literal text INCLUDING the surrounding
// double quotes (the visitors strip them).
void String() #String : { Token t; }
{ t=<STRING_LITERAL> { jjtThis.jjtSetValue(t.image); } }

LangCompiler

package lang;

import java.io.*;

/**
 * Lang Compiler.
 *
 * <p>Command-line front end: parses the script file named by the first
 * argument and hands the resulting syntax tree to a {@link LangCompilerVisitor}.
 */
public final class LangCompiler extends LangImplementation {

    /**
     * Creates a parser that reads the script from the given reader.
     *
     * @param stream source of the script text
     */
    public LangCompiler(Reader stream) {
        super(stream);
    }

    /** Prints the command-line usage to standard output. */
    static void printUsage() {
        System.out.println("usage: langc script-file");
    }

    /**
     * Compiles the script file given as {@code args[0]}.
     *
     * <p>Problems caused by the input (language errors, syntax errors, an
     * unreadable file) are reported on standard error as a one-line
     * {@code compile error: ...} message. Only unexpected runtime failures
     * print a stack trace.
     *
     * @param args command-line arguments; the first one is the script file
     */
    public static void main(String[] args) {
        if (args.length == 0) {
            printUsage();
            return;
        }
        try {
            Reader r = new FileReader(args[0]);
            try {
                LangCompiler t = new LangCompiler(r);
                t.Start().jjtAccept(new LangCompilerVisitor(), null);
            } finally {
                r.close();
            }
        } catch (LangError ex) {
            System.err.println("compile error: " + ex.getMessage());
        } catch (RuntimeException ex) {
            // Unexpected internal failure: keep the full trace for diagnosis.
            ex.printStackTrace();
        } catch (Exception ex) {
            // Checked exceptions here come from the parser (syntax error in the
            // script) or from file I/O; both are user-facing, so print a
            // message instead of a stack trace.
            System.err.println("compile error: " + ex.getMessage());
        }
    }

}

LangCompilerVisitor

package lang;

import static org.apache.bcel.Constants.*;

import java.io.*;
import java.util.*;

import org.apache.bcel.*;
import org.apache.bcel.classfile.*;
import org.apache.bcel.generic.*;

/**
 * Visitor that compiles a Lang syntax tree into a class file using BCEL.
 *
 * <p>All statements are emitted into a single static {@code main} method of a
 * class named {@link #CLASS_NAME}, which is written to the current directory
 * when the root node has been processed. Script variables become JVM local
 * variables; each variable's slot is its position in {@link #varList} and its
 * current static type is tracked in {@link #varTypeMap}.
 */
public final class LangCompilerVisitor implements LangImplementationVisitor {

    /** Name of the generated class (and base name of the class file). */
    static final String CLASS_NAME = "Lang";

    ClassGen cg;
    InstructionList il;
    InstructionFactory factory;

    /** Variable names in order of first assignment; the index is the local slot. */
    List<String> varList;
    /** Current static type of every assigned variable. */
    LinkedHashMap<String, Type> varTypeMap;

    LangCompilerVisitor() {
        this.il = new InstructionList();
        this.varList = new ArrayList<String>();
        this.varTypeMap = new LinkedHashMap<String, Type>();
    }

    /** Loads the empty template class the generated class is derived from. */
    private static JavaClass getBaseClass() {
        try {
            return Repository.lookupClass(EmptyClass.class);
        } catch (ClassNotFoundException ex) {
            throw new RuntimeException(ex);
        }
    }

    /** Template used as the starting point for the generated class. */
    static class EmptyClass {
        //
    }

    /**
     * Converts a literal node to the Java value it denotes.
     *
     * @param node a float, integer or string literal node
     * @return a {@code Float}, an {@code Integer}, or the string without its quotes
     * @throws LangError if the node is not a literal
     */
    static Object toRawObject(SimpleNode node) {
        final String s = String.valueOf(node.jjtGetValue());
        if (node instanceof LangFloat) {
            return Float.valueOf(s);
        }
        if (node instanceof LangInteger) {
            return Integer.valueOf(s);
        }
        if (node instanceof LangString) {
            // the token image still carries the surrounding double quotes
            return s.substring(1, s.length() - 1);
        }
        throw new LangError("toRawObject: " + node);
    }

    /**
     * Determines the static type of the value an expression node leaves on
     * the operand stack.
     *
     * <p>For a binary operator the result is the common operand type, or
     * {@code float} when the operand types differ. This matches the promotion
     * performed by {@link #calculate(SimpleNode, Object)}.
     *
     * @param node expression node
     * @return {@code Type.INT}, {@code Type.FLOAT} or {@code Type.STRING}
     * @throws LangError if the node is an unassigned variable or not an expression
     */
    Type getType(Node node) {
        if (node instanceof SimpleNode) {
            if ((node instanceof LangOpAdd) || (node instanceof LangOpMul)) {
                final Type left = getType(node.jjtGetChild(0));
                final Type right = getType(node.jjtGetChild(1));
                return (left == right) ? right : Type.FLOAT;
            }
            if (node instanceof LangIdentifier) {
                final Object name = ((SimpleNode)node).jjtGetValue();
                final Type type = varTypeMap.get(name);
                if (type == null) {
                    // Without this check the null would surface later as an
                    // obscure failure while emitting the load instruction.
                    throw new LangError("unassigned variable: " + name);
                }
                return type;
            }
            if (node instanceof LangFloat) {
                return Type.FLOAT;
            }
            if (node instanceof LangInteger) {
                return Type.INT;
            }
            if (node instanceof LangString) {
                return Type.STRING;
            }
        }
        throw new LangError("getType:" + node);
    }

    /**
     * Emits code for a binary arithmetic node: both operands, an int-to-float
     * conversion where the operand types differ, then the operator itself.
     *
     * @param node binary operator node whose value is the operator text
     * @param data visitor pass-through value
     * @return the pass-through value
     * @throws LangError if either operand is a string
     */
    Object calculate(SimpleNode node, Object data) {
        final String op = String.valueOf(node.jjtGetValue());
        final Node lhs = node.jjtGetChild(0);
        final Node rhs = node.jjtGetChild(1);
        final Type lhsType = getType(lhs);
        final Type rhsType = getType(rhs);
        if (lhsType == Type.STRING || rhsType == Type.STRING) {
            // There is no JVM arithmetic instruction for references; reject
            // every operator, not only "+".
            throw new LangError(op + " for string not supported");
        }
        final Type resultType = getType(node);
        data = lhs.jjtAccept(this, data);
        if (lhsType != resultType) {
            il.append(factory.createCast(lhsType, resultType));
        }
        data = rhs.jjtAccept(this, data);
        if (rhsType != resultType) {
            il.append(factory.createCast(rhsType, resultType));
        }
        il.append(InstructionFactory.createBinaryOperation(op, resultType));
        return data;
    }

    /** Prepares the class generator before the first statement is compiled. */
    void onBegin() {
        cg = new ClassGen(getBaseClass());
        cg.setClassName(CLASS_NAME);
        factory = new InstructionFactory(cg);
    }

    /**
     * Finishes the {@code main} method and writes the class file.
     *
     * @throws RuntimeException if the class file cannot be written
     */
    void onEnd() {
        il.append(InstructionConstants.RETURN);
        MethodGen mg = new MethodGen(ACC_PUBLIC | ACC_STATIC,
                                     Type.VOID,
                                     new Type[]{new ArrayType(Type.STRING, 1)},
                                     new String[]{"args"},
                                     "main",
                                     CLASS_NAME,
                                     il,
                                     cg.getConstantPool());
        mg.setMaxStack();
        mg.setMaxLocals();
        cg.addMethod(mg.getMethod());
        JavaClass c = cg.getJavaClass();
        c.setSourceFileName(CLASS_NAME + ".java");
        final String fileName = CLASS_NAME + ".class";
        try {
            c.dump(fileName);
        } catch (IOException ex) {
            throw new RuntimeException("failed to write " + fileName, ex);
        }
    }

    /** Fallback for node types that have no dedicated handler. */
    @Override
    public Object visit(SimpleNode node, Object data) {
        throw new AssertionError("compile error: unexpected token " + node);
    }

    /** Compiles the whole program and writes the class file. */
    @Override
    public Object visit(LangStart node, Object data) {
        onBegin();
        node.jjtSetValue("");
        data = node.childrenAccept(this, data);
        onEnd();
        return data;
    }

    /**
     * Emits {@code System.out.print(x)} or {@code System.out.println(x)}; the
     * overload is selected by the static type of the argument expression.
     */
    @Override
    public Object visit(LangFnPrint node, Object data) {
        final String fName = String.valueOf(node.jjtGetValue());
        il.append(factory.createGetStatic(System.class.getName(),
                                          "out",
                                          Type.getType(PrintStream.class)));
        data = node.childrenAccept(this, data);
        il.append(factory.createInvoke(java.io.PrintStream.class.getName(),
                                       fName,
                                       Type.VOID,
                                       new Type[]{getType(node.jjtGetChild(0))},
                                       INVOKEVIRTUAL));
        return data;
    }

    /**
     * Emits {@code System.out.printf(format, args...)}.
     *
     * <p>The first child must be a string literal or a string variable. The
     * remaining children may be any expressions (literals, variables or
     * arithmetic).
     *
     * @throws LangError if the first argument is not a string
     */
    @Override
    public Object visit(LangFnPrintf node, Object data) {
        il.append(factory.createGetStatic(System.class.getName(),
                                          "out",
                                          Type.getType(PrintStream.class)));
        final SimpleNode format = (SimpleNode)node.jjtGetChild(0);
        if (format instanceof LangString) {
            il.append(factory.createConstant(toRawObject(format)));
        } else if (format instanceof LangIdentifier && getType(format) == Type.STRING) {
            data = visit((LangIdentifier)format, data);
        } else {
            throw new LangError("printf 1st argument must be string");
        }
        data = appendPrintfArguments(node, data);
        il.append(factory.createInvoke(java.io.PrintStream.class.getName(),
                                       "printf",
                                       Type.getType(PrintStream.class),
                                       new Type[]{Type.STRING, new ArrayType(Type.OBJECT, 1)},
                                       INVOKEVIRTUAL));
        // printf returns the stream; discard it
        il.append(InstructionFactory.createPop(1));
        return data;
    }

    /**
     * Emits code that builds the {@code Object[]} of printf arguments from
     * children 1..n of the node, boxing int and float values.
     */
    private Object appendPrintfArguments(LangFnPrintf node, Object data) {
        final int count = node.jjtGetNumChildren() - 1;
        il.append(factory.createConstant(Integer.valueOf(count)));
        il.append(factory.createNewArray(Type.OBJECT, (short)1));
        for (int i = 0; i < count; i++) {
            final Node arg = node.jjtGetChild(i + 1);
            final Type argType = getType(arg);
            il.append(InstructionFactory.createDup(1));
            il.append(factory.createConstant(Integer.valueOf(i)));
            // Evaluate the argument with the regular expression code so that
            // variables and arithmetic work, not only literals.
            data = arg.jjtAccept(this, data);
            if (argType == Type.FLOAT) {
                il.append(factory.createInvoke(Float.class.getName(),
                                               "valueOf",
                                               Type.getType(Float.class),
                                               new Type[]{Type.FLOAT},
                                               INVOKESTATIC));
            } else if (argType == Type.INT) {
                il.append(factory.createInvoke(Integer.class.getName(),
                                               "valueOf",
                                               Type.getType(Integer.class),
                                               new Type[]{Type.INT},
                                               INVOKESTATIC));
            }
            il.append(InstructionFactory.createArrayStore(Type.OBJECT));
        }
        return data;
    }

    /**
     * Emits the right-hand side and a store into the variable's local slot.
     * The recorded type is updated on every assignment so that later loads
     * use the type of the value actually stored.
     */
    @Override
    public Object visit(LangOpAssign node, Object data) {
        final LangIdentifier identifier = (LangIdentifier)node.jjtGetChild(0);
        final String varName = String.valueOf(identifier.jjtGetValue());
        final Node rhs = node.jjtGetChild(1);
        data = rhs.jjtAccept(this, data);
        // Determine the type before updating the map: the right-hand side may
        // refer to the variable's previous type (e.g. "a = a + 1.0").
        final Type type = getType(rhs);
        if (!varList.contains(varName)) {
            varList.add(varName);
        }
        varTypeMap.put(varName, type);
        final int index = varList.indexOf(varName);
        il.append(InstructionFactory.createStore(type, index));
        return data;
    }

    /** Emits an addition or subtraction. */
    @Override
    public Object visit(LangOpAdd node, Object data) {
        return calculate(node, data);
    }

    /** Emits a multiplication, division or remainder. */
    @Override
    public Object visit(LangOpMul node, Object data) {
        return calculate(node, data);
    }

    /**
     * Emits a load of the variable's local slot.
     *
     * @throws LangError if the variable has not been assigned
     */
    @Override
    public Object visit(LangIdentifier node, Object data) {
        final String varName = String.valueOf(node.jjtGetValue());
        final Type type = getType(node);
        final int index = varList.indexOf(varName);
        il.append(InstructionFactory.createLoad(type, index));
        return data;
    }

    /** Emits a float constant. */
    @Override
    public Object visit(LangFloat node, Object data) {
        final Float f = Float.valueOf(String.valueOf(node.jjtGetValue()));
        il.append(factory.createConstant(f));
        return data;
    }

    /** Emits an int constant. */
    @Override
    public Object visit(LangInteger node, Object data) {
        final Integer i = Integer.valueOf(String.valueOf(node.jjtGetValue()));
        il.append(factory.createConstant(i));
        return data;
    }

    /** Emits a string constant (without the surrounding quotes). */
    @Override
    public Object visit(LangString node, Object data) {
        il.append(factory.createConstant(toRawObject(node)));
        return data;
    }

}

LangInterpreter

package lang;

import java.io.*;

/**
 * Lang Interpreter.
 *
 * <p>Command-line front end: parses a script file and executes it with a
 * {@link LangInterpreterVisitor}. An optional leading {@code -d} switches on
 * the visitor's debug output.
 */
public final class LangInterpreter extends LangImplementation {

    /**
     * Creates a parser that reads the script from the given reader.
     *
     * @param stream source of the script text
     */
    public LangInterpreter(Reader stream) {
        super(stream);
    }

    /** Prints the command-line usage to standard output. */
    static void printUsage() {
        System.out.println("usage: langi [-d] script-file");
    }

    /**
     * Runs the script named on the command line.
     *
     * @param args {@code [-d] script-file}
     */
    public static void main(String[] args) {
        // "-d" is only recognised as the very first argument; the script file
        // is the argument right after it (or the first one when "-d" is absent).
        final boolean debug = args.length > 0 && args[0].equals("-d");
        final int fileIndex = debug ? 1 : 0;
        if (args.length <= fileIndex) {
            printUsage();
            return;
        }
        final String file = args[fileIndex];
        try {
            final Reader source = new FileReader(file);
            try {
                final LangInterpreter parser = new LangInterpreter(source);
                final LangInterpreterVisitor visitor = new LangInterpreterVisitor(debug);
                parser.Start().jjtAccept(visitor, null);
            } finally {
                source.close();
            }
        } catch (LangError ex) {
            System.err.println("error: " + ex.getMessage());
        } catch (ParseException ex) {
            System.err.println("error: " + ex.getMessage());
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

}

LangInterpreterVisitor

package lang;

import java.util.*;

/**
 * Visitor that executes a Lang syntax tree directly.
 *
 * <p>Expressions are evaluated on an explicit operand stack: literal and
 * variable nodes push their value, operator nodes pop two operands and push
 * the result, and statements pop what they consume. Variables live in a
 * name-to-value map. With debugging enabled, the stack, the variables and the
 * visited node are printed on entry to and exit from each node.
 */
public class LangInterpreterVisitor implements LangImplementationVisitor {

    private Stack<Object> stack;
    private Map<String, Object> varMap;
    private boolean debug;
    private int indent;

    /** Creates an interpreter without debug output. */
    LangInterpreterVisitor() {
        this(false);
    }

    /**
     * Creates an interpreter.
     *
     * @param debug whether to trace every visited node
     */
    LangInterpreterVisitor(boolean debug) {
        this.stack = new Stack<Object>();
        this.varMap = new HashMap<String, Object>();
        this.debug = debug;
        this.indent = 0;
    }

    /** Prints the start banner in debug mode. */
    void onBegin() {
        if (debug) {
            System.out.println("*** Lang Interpreter START ***");
        }
    }

    /** Prints the remaining stack and the end banner in debug mode. */
    void onEnd() {
        if (debug) {
            System.out.printf("final values = %s%n", stack);
            System.out.println("*** Lang Interpreter E N D ***");
        }
    }

    /**
     * Resolves a stack entry to a plain value.
     *
     * @param o a {@code Float}, {@code Integer}, {@code String} or an identifier node
     * @return the value itself, or the value of the variable the identifier names
     * @throws LangError if the variable is unassigned or the entry cannot be evaluated
     */
    Object evaluate(Object o) {
        if (o instanceof Float || o instanceof Integer || o instanceof String) {
            return o;
        }
        if (o instanceof LangIdentifier) {
            final String id = String.valueOf(((LangIdentifier)o).jjtGetValue());
            if (!varMap.containsKey(id)) {
                throw new LangError("unassigned variable: " + id);
            }
            return varMap.get(id);
        }
        // null is reported like any other unsupported entry instead of
        // failing with a NullPointerException
        throw new LangError("can't evaluate: " + (o == null ? null : o.getClass()));
    }

    /**
     * Applies an arithmetic operator to two ints.
     *
     * @param op one of {@code + - * / %}
     * @return the result
     * @throws ArithmeticException on division or remainder by zero
     */
    static Integer calculate(String op, int a, int b) {
        switch (op.length() == 1 ? op.charAt(0) : ' ') {
            case '+':
                return a + b;
            case '-':
                return a - b;
            case '*':
                return a * b;
            case '/':
                return a / b;
            case '%':
                return a % b;
        }
        throw new AssertionError();
    }

    /**
     * Applies an arithmetic operator to two floats.
     *
     * @param op one of {@code + - * / %}
     * @return the result
     */
    static Float calculate(String op, float a, float b) {
        switch (op.length() == 1 ? op.charAt(0) : ' ') {
            case '+':
                return a + b;
            case '-':
                return a - b;
            case '*':
                return a * b;
            case '/':
                return a / b;
            case '%':
                return a % b;
        }
        throw new AssertionError();
    }

    /** Formats a node and its value for the debug trace. */
    static String p(SimpleNode node) {
        return String.format("%s %s", node, node.jjtGetValue());
    }

    /** Traces entry into a node (debug mode only). */
    void printCodeS(SimpleNode node) {
        printCode(node, ">>> ");
    }

    /** Traces exit from a node (debug mode only). */
    void printCodeE(SimpleNode node) {
        printCode(node, "<<< ");
    }

    /** Prints the stack, the variables and the node, indented by nesting depth. */
    void printCode(SimpleNode node, String prefix) {
        if (!debug) {
            return;
        }
        System.out.println("  ### stack = " + stack);
        System.out.println("  ### vars  = " + varMap);
        char[] a = new char[indent * 2];
        Arrays.fill(a, ' ');
        System.out.printf("%s%s%s%n", prefix, String.valueOf(a), p(node));
    }

    /**
     * Shared implementation of the binary operator nodes: evaluates both
     * operands, applies the operator and pushes the result.
     *
     * <p>Two integers are combined as ints; any other pair of numbers as
     * floats. When {@code allowConcat} is set and the left operand is a
     * string, {@code +} concatenates.
     *
     * @throws LangError if the operands do not support the operator, or on
     *         integer division or remainder by zero
     */
    private Object binaryOperation(SimpleNode node, Object data, boolean allowConcat) {
        printCodeS(node);
        final String op = String.valueOf(node.jjtGetValue());
        ++indent;
        data = node.childrenAccept(this, data);
        final Object o2 = stack.pop();
        final Object o1 = stack.pop();
        final Object r;
        if (o1 instanceof Number && o2 instanceof Number) {
            final Number n1 = (Number)o1;
            final Number n2 = (Number)o2;
            if (n1 instanceof Integer && n2 instanceof Integer) {
                try {
                    r = calculate(op, n1.intValue(), n2.intValue());
                } catch (ArithmeticException ex) {
                    // integer "/" or "%" by zero: report it as a script error
                    throw new LangError(String.format("%s %s %s", o1, op, o2));
                }
            } else {
                r = calculate(op, n1.floatValue(), n2.floatValue());
            }
        } else if (allowConcat && o1 instanceof String && op.equals("+")) {
            r = ((String)o1) + o2;
        } else {
            throw new LangError(String.format("%s %s %s", o1, op, o2));
        }
        stack.push(r);
        --indent;
        printCodeE(node);
        return data;
    }

    /** Fallback for node types that have no dedicated handler. */
    @Override
    public Object visit(SimpleNode node, Object data) {
        throw new AssertionError("unexpected token " + node);
    }

    /** Runs the whole program. */
    @Override
    public Object visit(LangStart node, Object data) {
        onBegin();
        node.jjtSetValue("");
        printCodeS(node);
        ++indent;
        data = node.childrenAccept(this, data);
        --indent;
        onEnd();
        return data;
    }

    /** Executes {@code print} / {@code println} on the value of the single child. */
    @Override
    public Object visit(LangFnPrint node, Object data) {
        printCodeS(node);
        ++indent;
        data = node.childrenAccept(this, data);
        final String fname = String.valueOf(node.jjtGetValue());
        final Object o = evaluate(stack.pop());
        if (fname.equals("print")) {
            System.out.print(o);
        } else if (fname.equals("println")) {
            System.out.println(o);
        }
        --indent;
        printCodeE(node);
        return data;
    }

    /**
     * Executes {@code printf}: the first child is the format, the remaining
     * children are the arguments.
     *
     * @throws LangError if there are fewer than two children, the format is
     *         not a string, or the format does not match the arguments
     */
    @Override
    public Object visit(LangFnPrintf node, Object data) {
        printCodeS(node);
        ++indent;
        final int n = node.jjtGetNumChildren();
        if (n <= 1) {
            throw new LangError("printf requires at least 2 arguments");
        }
        data = node.childrenAccept(this, data);
        // children were pushed left to right, so pop them in reverse
        final Object[] args = new Object[n - 1];
        for (int i = args.length - 1; i >= 0; i--) {
            args[i] = stack.pop();
        }
        final Object format = stack.pop();
        if (!(format instanceof String)) {
            // same rule and message as the compiler
            throw new LangError("printf 1st argument must be string");
        }
        try {
            System.out.printf((String)format, args);
        } catch (IllegalFormatException ex) {
            throw new LangError("printf: " + ex.getMessage());
        }
        --indent;
        printCodeE(node);
        return data;
    }

    /** Evaluates the right-hand side and binds the result to the variable. */
    @Override
    public Object visit(LangOpAssign node, Object data) {
        printCodeS(node);
        LangIdentifier identifier = (LangIdentifier)node.jjtGetChild(0);
        final String varName = String.valueOf(identifier.jjtGetValue());
        ++indent;
        // evaluate only RHS
        data = node.jjtGetChild(1).jjtAccept(this, data);
        final Object a = evaluate(stack.pop());
        varMap.put(varName, a);
        --indent;
        printCodeE(node);
        return data;
    }

    /** Evaluates {@code +} / {@code -}; {@code +} also concatenates when the left operand is a string. */
    @Override
    public Object visit(LangOpAdd node, Object data) {
        return binaryOperation(node, data, true);
    }

    /** Evaluates {@code *}, {@code /} and {@code %} on numbers. */
    @Override
    public Object visit(LangOpMul node, Object data) {
        return binaryOperation(node, data, false);
    }

    /**
     * Pushes the current value of a variable.
     *
     * @throws LangError if the variable has not been assigned
     */
    @Override
    public Object visit(LangIdentifier node, Object data) {
        printCodeS(node);
        final String name = String.valueOf(node.jjtGetValue());
        if (!varMap.containsKey(name)) {
            // Pushing null here would print "null" or fail later with a
            // confusing message; report the actual problem instead.
            throw new LangError("unassigned variable: " + name);
        }
        stack.push(varMap.get(name));
        printCodeE(node);
        return data;
    }

    /** Pushes a float literal. */
    @Override
    public Object visit(LangFloat node, Object data) {
        printCodeS(node);
        stack.push(Float.valueOf(String.valueOf(node.jjtGetValue())));
        printCodeE(node);
        return data;
    }

    /** Pushes an integer literal. */
    @Override
    public Object visit(LangInteger node, Object data) {
        printCodeS(node);
        stack.push(Integer.valueOf(String.valueOf(node.jjtGetValue())));
        printCodeE(node);
        return data;
    }

    /** Pushes a string literal without its surrounding quotes. */
    @Override
    public Object visit(LangString node, Object data) {
        printCodeS(node);
        final String s = (String)node.jjtGetValue();
        stack.push(s.substring(1, s.length() - 1));
        printCodeE(node);
        return data;
    }

}

こんな感じで、それっぽいものはできました。
本格的なものにするにはまだまだ先は長いですが、地道にやってみます。


今回はこれで終わりです。