feat: 重构编译器以支持函数定义和调用,添加新的字节码以支持函数调用

另外,我很高兴地宣布,fib(40) 递归法 在我的平台, i5-13490f,只需要 6600ms, fib(30) 56ms
这是历史性的一刻!
This commit is contained in:
2026-03-07 00:34:52 +08:00
parent 1fe9ccf7ea
commit 6dbecbbdc0
13 changed files with 477 additions and 127 deletions

View File

@@ -20,9 +20,17 @@ namespace Fig
enum class OpCode : std::uint8_t enum class OpCode : std::uint8_t
{ {
Exit, // 结束运行 Exit, // 结束运行
LoadK, // iABx 模式: R[A] = Constants[Bx] LoadK, // iABx 模式: R[A] = Constants[Bx]
Return, // iA 模式: 返回 R[A] 的值 LoadTrue, // iABC: R[A] = true
LoadFalse, // iABC: R[A] = false
LoadNull, // iABC: R[A] = null
FastCall, // iABC: A: ProtoIdx, B: 函数起始寄存器
Call, // 动态派发 iABC: A: 函数体对象寄存器 B: 函数起始寄存器
Return, // iABC 模式: 返回 R[A] 的值
LoadFn, // 惰性装修, iABx: R[A] = new FunctionObject...
Jmp, // iAsBx: ip += sBx 无条件跳转 Jmp, // iAsBx: ip += sBx 无条件跳转
JmpIfFalse, // iAsBx: 如果 R[A] 为假, ip += sBx JmpIfFalse, // iAsBx: 如果 R[A] 为假, ip += sBx

View File

@@ -1,9 +1,10 @@
#include <Compiler/Compiler.hpp>
#include <Core/Core.hpp>
#include <Deps/Deps.hpp>
#include <Lexer/Lexer.hpp> #include <Lexer/Lexer.hpp>
#include <Parser/Parser.hpp> #include <Parser/Parser.hpp>
#include <Deps/Deps.hpp>
#include <Core/Core.hpp>
#include <SourceManager/SourceManager.hpp> #include <SourceManager/SourceManager.hpp>
#include <Compiler/Compiler.hpp>
#include <iostream> #include <iostream>
#include <print> #include <print>
@@ -24,7 +25,7 @@ int main()
return 1; return 1;
} }
Lexer lexer(manager.GetSource(), fileName); Lexer lexer(manager.GetSource(), fileName);
Parser parser(lexer, manager, fileName); Parser parser(lexer, manager, fileName);
const auto &program_result = parser.Parse(); const auto &program_result = parser.Parse();
@@ -35,24 +36,31 @@ int main()
} }
Program *program = *program_result; Program *program = *program_result;
Compiler compiler(fileName, manager); Compiler compiler(fileName, manager);
const auto &proto_result = compiler.Compile(program); const auto &comp_result = compiler.Compile(program);
if (!proto_result) if (!comp_result)
{ {
ReportError(proto_result.error(), manager); ReportError(comp_result.error(), manager);
return 1; return 1;
} }
Proto *proto = *proto_result; CompiledModule *compiledModule = *comp_result;
std::cout << "=== Constant Pool ===" << '\n'; size_t cnt = 0;
for (size_t i = 0; i < proto->constants.size(); ++i) for (Proto *proto : compiledModule->protos)
{ {
std::print("[{}] {}\n", i, proto->constants[i].ToString()); std::cout << "=====================\n"
<< "Proto: " << cnt++ << '\n';
std::cout << "=== Constant Pool ===" << '\n';
for (size_t i = 0; i < proto->constants.size(); ++i)
{
std::print("[{}] {}\n", i, proto->constants[i].ToString());
}
DumpCode(proto->code);
std::cout << "\nMax Stack Size: " << (int) proto->maxStack << std::endl;
} }
DumpCode(proto->code);
std::cout << "\nMax Stack Size: " << (int) proto->maxStack << std::endl;
return 0; return 0;
} }

View File

@@ -9,9 +9,8 @@
namespace Fig namespace Fig
{ {
Result<Proto *, Error> Compiler::Compile(Program *program) Result<CompiledModule *, Error> Compiler::Compile(Program *program)
{ {
current->proto = new Proto();
current->freeReg = 0; current->freeReg = 0;
for (Stmt *stmt : program->nodes) for (Stmt *stmt : program->nodes)
@@ -22,7 +21,16 @@ namespace Fig
return std::unexpected(result.error()); return std::unexpected(result.error());
} }
} }
if (mainFuncIndex != -1)
{
std::uint8_t baseReg = AllocReg();
Emit(Op::iABC(OpCode::FastCall, mainFuncIndex, baseReg, 0));
}
Emit(Op::iABC(OpCode::Exit, 0, 0, 0)); // 一定要退出,这是虚拟机退出信号,否则ub Emit(Op::iABC(OpCode::Exit, 0, 0, 0)); // 一定要退出,这是虚拟机退出信号,否则ub
return current->proto;
CompiledModule *compiledModule = new CompiledModule(fileName, allProtos);
return compiledModule;
} }
}; // namespace Fig }; // namespace Fig

View File

@@ -19,7 +19,7 @@
namespace Fig namespace Fig
{ {
// 编译产物 // 编译产物-函数
struct Proto struct Proto
{ {
DynArray<Instruction> code; DynArray<Instruction> code;
@@ -34,7 +34,7 @@ namespace Fig
int depth; // 物理作用域深度(用于 EndScope 释放寄存器) int depth; // 物理作用域深度(用于 EndScope 释放寄存器)
}; };
inline constexpr int MAX_LOCALS = 250; inline constexpr int MAX_LOCALS = 250;
inline constexpr int MAX_CONSTANTS = UINT16_MAX + 1; inline constexpr int MAX_CONSTANTS = UINT16_MAX + 1;
// 任何跨函数、跨模块的编译,都压入弹出这个 State // 任何跨函数、跨模块的编译,都压入弹出这个 State
@@ -58,6 +58,25 @@ namespace Fig
// 注意:这里不 delete proto因为 proto 是要作为编译产物吐出去的 // 注意:这里不 delete proto因为 proto 是要作为编译产物吐出去的
}; };
struct CompiledModule
{
String name; // 供调试/打印
DynArray<Proto *> protos; // 扁平化函数原型
CompiledModule(String _name, DynArray<Proto *> _protos) :
name(std::move(_name)), protos(std::move(_protos))
{
}
~CompiledModule()
{
for (auto *p : protos)
{
delete p;
}
}
};
class Compiler class Compiler
{ {
private: private:
@@ -65,7 +84,13 @@ namespace Fig
SourceManager &manager; SourceManager &manager;
FuncState *current = nullptr; // 永远指向当前正在编译的上下文 FuncState *current = nullptr; // 永远指向当前正在编译的上下文
int mainFuncIndex = -1;
HashMap<int, int> globalFuncMap; // localid -> ProtoIdx
public: public:
DynArray<Proto *> allProtos;
struct FuncStateProtector struct FuncStateProtector
{ {
Compiler *compiler; Compiler *compiler;
@@ -89,6 +114,7 @@ namespace Fig
{ {
// 初始化顶级作用域 // 初始化顶级作用域
current = new FuncState("global", nullptr); current = new FuncState("global", nullptr);
allProtos.push_back(current->proto); // 最顶层, bootstrapper
} }
~Compiler() ~Compiler()
@@ -102,7 +128,8 @@ namespace Fig
} }
} }
Result<Proto *, Error> Compile(Program *program); Result<CompiledModule *, Error> Compile(Program *program);
private: private:
void PushState(String _name) void PushState(String _name)
{ {
@@ -158,7 +185,6 @@ namespace Fig
std::uint16_t AddConstant(Value v) std::uint16_t AddConstant(Value v)
{ {
// TODO: 查重
auto it = auto it =
std::find(current->proto->constants.begin(), current->proto->constants.end(), v); std::find(current->proto->constants.begin(), current->proto->constants.end(), v);
if (it != current->proto->constants.end()) if (it != current->proto->constants.end())
@@ -245,6 +271,8 @@ namespace Fig
Result<std::uint8_t, Error> compileLeftValue( Result<std::uint8_t, Error> compileLeftValue(
Expr *); // 左值对象,可以是变量、结构体字段或模块对象 Expr *); // 左值对象,可以是变量、结构体字段或模块对象
Result<std::uint8_t, Error> compileCallExpr(CallExpr *);
Result<std::uint8_t, Error> compileExpr(Expr *); Result<std::uint8_t, Error> compileExpr(Expr *);
/* Statements */ /* Statements */
@@ -252,6 +280,8 @@ namespace Fig
Result<void, Error> compileBlockStmt(BlockStmt *); Result<void, Error> compileBlockStmt(BlockStmt *);
Result<void, Error> compileIfStmt(IfStmt *); Result<void, Error> compileIfStmt(IfStmt *);
Result<void, Error> compileWhileStmt(WhileStmt *); Result<void, Error> compileWhileStmt(WhileStmt *);
Result<void, Error> compileFnDefStmt(FnDefStmt *);
Result<void, Error> compileReturnStmt(ReturnStmt *);
Result<void, Error> compileStmt(Stmt *); Result<void, Error> compileStmt(Stmt *);
}; };
@@ -271,6 +301,10 @@ namespace Fig
switch (op) switch (op)
{ {
case OpCode::Exit: {
break;
}
case OpCode::Mov: { case OpCode::Mov: {
// iABx 模式 // iABx 模式
std::uint16_t bx = (inst >> 16) & 0xFFFF; std::uint16_t bx = (inst >> 16) & 0xFFFF;
@@ -292,6 +326,22 @@ namespace Fig
break; break;
} }
case OpCode::FastCall:
{
std::uint8_t b = (inst >> 16) & 0xFF;
std::cout << std::format("Proto{:<3} R[{}]+", a, b);
break;
}
case OpCode::Call:
{
std::uint8_t b = (inst >> 16) & 0xFF;
std::cout << std::format("R{:<3} R[{}]+", a, b);
break;
}
case OpCode::LoadTrue:
case OpCode::LoadFalse:
case OpCode::LoadNull:
case OpCode::Add: case OpCode::Add:
case OpCode::Sub: case OpCode::Sub:
case OpCode::Mul: case OpCode::Mul:
@@ -308,17 +358,24 @@ namespace Fig
std::cout << std::format("R{}", a); std::cout << std::format("R{}", a);
break; break;
} }
default: {
std::cout << "?"; case OpCode::LoadFn: {
std::uint16_t bx = (inst >> 16) & 0xFFFF;
std::cout << std::format("R{:<3} Proto[{}]", a, bx);
break; break;
} }
// default: {
// std::cout << "?";
// break;
// }
} }
std::cout << '\n'; std::cout << '\n';
} }
inline void DumpCode(const DynArray<Instruction> &code) inline void DumpCode(const DynArray<Instruction> &code)
{ {
std::cout << "=== Bytecode ===\n"; std::cout << " Bytecode\n";
for (std::size_t i = 0; i < code.size(); ++i) for (std::size_t i = 0; i < code.size(); ++i)
{ {
DisassembleInstruction(code[i], i); DisassembleInstruction(code[i], i);

View File

@@ -65,19 +65,17 @@ namespace Fig
assert("false" && "CompileLiteral: unsupport literal"); assert("false" && "CompileLiteral: unsupport literal");
} }
std::uint8_t targetReg = AllocReg(); std::uint8_t targetReg = AllocReg();
if (current->proto->constants.size() >= MAX_CONSTANTS) if (current->proto->constants.size() >= MAX_CONSTANTS)
{ {
return std::unexpected(Error( return std::unexpected(Error(ErrorType::TooManyConstants,
ErrorType::TooManyConstants,
std::format("constant limit exceeded: {}", MAX_CONSTANTS), std::format("constant limit exceeded: {}", MAX_CONSTANTS),
"How did you write such code? try global variable or split file", "How did you write such code? try global variable or split file",
makeSourceLocation(lit) makeSourceLocation(lit)));
));
} }
std::uint16_t kIndex = AddConstant(v); std::uint16_t kIndex = AddConstant(v);
Emit(Op::iABx(OpCode::LoadK, targetReg, kIndex)); Emit(Op::iABx(OpCode::LoadK, targetReg, kIndex));
return targetReg; return targetReg;
@@ -230,6 +228,75 @@ namespace Fig
return 0; return 0;
} }
} }
Result<std::uint8_t, Error> Compiler::compileCallExpr(CallExpr *expr)
{
bool isStatic = false; // 是否为单纯的 fn(...) 静态函数调用
int protoIdx = -1;
if (expr->callee->type == AstType::IdentiExpr)
{
IdentiExpr *id = static_cast<IdentiExpr *>(expr->callee);
// 如果是函数名且深度为 0 (全局/扁平函数池)
if (id->resolvedType->tag == TypeTag::Function && id->resolvedDepth == 0)
{
if (globalFuncMap.contains(id->localId))
{
isStatic = true;
protoIdx = globalFuncMap[id->localId];
}
}
}
std::uint8_t baseReg = AllocReg();
if (!isStatic)
{
auto calleeRes = compileExpr(expr->callee);
if (!calleeRes)
{
return calleeRes;
}
if (*calleeRes != baseReg)
{
Emit(Op::iABx(OpCode::Mov, baseReg, *calleeRes));
}
}
for (size_t i = 0; i < expr->args.size(); ++i)
{
std::uint8_t argTarget = AllocReg();
auto argRes = compileExpr(expr->args.args[i]);
if (!argRes)
{
return argRes;
}
if (*argRes != argTarget)
{
Emit(Op::iABx(OpCode::Mov, argTarget, *argRes));
}
}
std::uint8_t expectRet = 1;
if (isStatic)
{
Emit(Op::iABC(OpCode::FastCall, (std::uint8_t) protoIdx, baseReg, expectRet));
}
else
{
Emit(Op::iABC(OpCode::Call, baseReg, baseReg, expectRet));
}
for (size_t i = 0; i < expr->args.args.size(); ++i)
{
current->freeReg--;
}
return baseReg; // 返回值起点
}
Result<std::uint8_t, Error> Compiler::compileExpr( Result<std::uint8_t, Error> Compiler::compileExpr(
Expr *expr) // 编译表达式,必定返回一个存放结果的寄存器 ID Expr *expr) // 编译表达式,必定返回一个存放结果的寄存器 ID
{ {
@@ -242,6 +309,7 @@ namespace Fig
case AstType::IdentiExpr: { case AstType::IdentiExpr: {
return compileLeftValue(expr); // 左值直接转换成右值 return compileLeftValue(expr); // 左值直接转换成右值
} }
case AstType::LiteralExpr: { case AstType::LiteralExpr: {
LiteralExpr *lit = static_cast<LiteralExpr *>(expr); LiteralExpr *lit = static_cast<LiteralExpr *>(expr);
@@ -253,9 +321,14 @@ namespace Fig
std::uint8_t targetReg = *result; std::uint8_t targetReg = *result;
return targetReg; return targetReg;
} }
case AstType::InfixExpr: { case AstType::InfixExpr: {
return compileInfixExpr(static_cast<InfixExpr *>(expr)); return compileInfixExpr(static_cast<InfixExpr *>(expr));
} }
case AstType::CallExpr: {
return compileCallExpr(static_cast<CallExpr *>(expr));
}
} }
} }
} // namespace Fig } // namespace Fig

View File

@@ -175,6 +175,67 @@ namespace Fig
return {}; return {};
} }
Result<void, Error> Compiler::compileFnDefStmt(FnDefStmt *stmt)
{
std::uint8_t funcReg = DeclareLocal(stmt->localId);
// 创建子函数编译状态
// 传入 current 作为 enclosing用于后续支持闭包 Upvalue 查找
FuncState childState(stmt->name, current);
allProtos.push_back(childState.proto);
std::uint16_t protoIdx = static_cast<std::uint16_t>(allProtos.size() - 1);
globalFuncMap[stmt->localId] = protoIdx; // 把函数的local id映射到protoIdx
{
FuncStateProtector stateGuard(this, &childState);
AllocReg();
// 将参数映射为子函数的初始局部变量 (R0, R1...)
for (Param *p : stmt->params)
{
PosParam *posParam = static_cast<PosParam *>(p); // TODO: 其他参数支持...
// 按顺序分配寄存器
DeclareLocal(posParam->localId);
}
// B编译函数体语句
auto bodyRes = compileStmt(stmt->body);
if (!bodyRes)
return bodyRes;
// 隐式返回 null
std::uint8_t resReg = AllocReg();
Emit(Op::iABC(OpCode::LoadNull, resReg, 0, 0));
Emit(Op::iABC(OpCode::Return, resReg, 0, 0));
}
// 5. 检查是否是 main 函数
if (stmt->name == U"main")
{
this->mainFuncIndex = protoIdx;
}
Emit(Op::iABx(OpCode::LoadFn, funcReg, protoIdx));
return {};
}
Result<void, Error> Compiler::compileReturnStmt(ReturnStmt *stmt)
{
auto res = compileExpr(stmt->value);
if (!res)
{
return std::unexpected(res.error());
}
Emit(Op::iABC(OpCode::Return, *res, 0, 0));
return {};
}
Result<void, Error> Compiler::compileStmt(Stmt *stmt) // 编译语句 Result<void, Error> Compiler::compileStmt(Stmt *stmt) // 编译语句
{ {
switch (stmt->type) switch (stmt->type)
@@ -206,6 +267,14 @@ namespace Fig
case AstType::WhileStmt: { case AstType::WhileStmt: {
return compileWhileStmt(static_cast<WhileStmt *>(stmt)); return compileWhileStmt(static_cast<WhileStmt *>(stmt));
} }
case AstType::FnDefStmt: {
return compileFnDefStmt(static_cast<FnDefStmt *>(stmt));
}
case AstType::ReturnStmt: {
return compileReturnStmt(static_cast<ReturnStmt *>(stmt));
}
} }
return Result<void, Error>(); return Result<void, Error>();

View File

@@ -9,7 +9,7 @@
namespace Fig namespace Fig
{ {
constexpr String Value::ToString() const String Value::ToString() const
{ {
if (IsNull()) if (IsNull())
{ {

View File

@@ -182,7 +182,7 @@ namespace Fig
// 类函数 // 类函数
[[nodiscard]] [[nodiscard]]
constexpr String ToString() const; String ToString() const;
}; };
/* /*

View File

@@ -282,7 +282,8 @@ namespace Fig
} }
TypeInfo *valueType = stmt->value->resolvedType; TypeInfo *valueType = stmt->value->resolvedType;
if (currentReturnType != typeCtx.GetAny() && currentReturnType != valueType)
if (!currentReturnType->isAny() && !valueType->isAny() && currentReturnType != valueType)
{ {
return std::unexpected(Error(ErrorType::TypeError, return std::unexpected(Error(ErrorType::TypeError,
std::format("return type mismatch: expects '{}', got `{}`", std::format("return type mismatch: expects '{}', got `{}`",
@@ -500,6 +501,36 @@ namespace Fig
return {}; return {};
} }
Result<void, Error> Analyzer::analyzeCallExpr(CallExpr *expr)
{
auto calleeRes = analyzeExpr(expr->callee);
if (!calleeRes)
{
return calleeRes;
}
if (expr->callee->resolvedType != typeCtx.GetAny()
&& expr->callee->resolvedType != typeCtx.GetFunction())
{
return std::unexpected(Error(ErrorType::TypeError,
std::format("object `{}` is not callable", expr->callee->toString()),
"none",
makeSourceLocation(expr->callee)));
}
for (auto *arg : expr->args.args)
{
auto argRes = analyzeExpr(arg);
if (!argRes)
{
return argRes;
}
}
expr->resolvedType = typeCtx.GetAny();
return {};
}
Result<void, Error> Analyzer::analyzeStmt(Stmt *stmt) Result<void, Error> Analyzer::analyzeStmt(Stmt *stmt)
{ {
if (!stmt) if (!stmt)
@@ -596,10 +627,17 @@ namespace Fig
return {}; return {};
} }
case AstType::IdentiExpr: return analyzeIdentiExpr(static_cast<IdentiExpr *>(expr)); case AstType::IdentiExpr: {
return analyzeIdentiExpr(static_cast<IdentiExpr *>(expr));
}
case AstType::InfixExpr: case AstType::InfixExpr: {
return analyzeInfixExpr(static_cast<InfixExpr *>(expr)); return analyzeInfixExpr(static_cast<InfixExpr *>(expr));
}
case AstType::CallExpr: {
return analyzeCallExpr(static_cast<CallExpr *>(expr));
}
// TODO: PrefixExpr (前缀), CallExpr (函数调用), MemberExpr (属性访问) // TODO: PrefixExpr (前缀), CallExpr (函数调用), MemberExpr (属性访问)

View File

@@ -87,6 +87,7 @@ namespace Fig
Result<void, Error> analyzeIdentiExpr(IdentiExpr *); Result<void, Error> analyzeIdentiExpr(IdentiExpr *);
Result<void, Error> analyzeInfixExpr(InfixExpr *); Result<void, Error> analyzeInfixExpr(InfixExpr *);
Result<void, Error> analyzeCallExpr(CallExpr *);
Result<void, Error> analyzeStmt(Stmt *); Result<void, Error> analyzeStmt(Stmt *);
Result<void, Error> analyzeExpr(Expr *); Result<void, Error> analyzeExpr(Expr *);

View File

@@ -7,79 +7,85 @@
#include <VM/VM.hpp> #include <VM/VM.hpp>
#define BINARY_ARITHMETIC_OP(opCode, op) \ #define BINARY_ARITHMETIC_OP(opCode, op) \
case OpCode::opCode: { \ case OpCode::opCode: { \
std::uint8_t b = decodeB(inst); \ std::uint8_t b = decodeB(inst); \
std::uint8_t c = decodeC(inst); \ std::uint8_t c = decodeC(inst); \
Value lhs = registers[b]; \ Value lhs = currentFrame->registerBase[b]; \
Value rhs = registers[c]; \ Value rhs = currentFrame->registerBase[c]; \
if (lhs.IsInt() && rhs.IsInt()) [[likely]] \ if (lhs.IsInt() && rhs.IsInt()) [[likely]] \
{ \ { \
registers[a] = Value::FromInt(lhs.AsInt() op rhs.AsInt()); \ currentFrame->registerBase[a] = Value::FromInt(lhs.AsInt() op rhs.AsInt()); \
} \ } \
else if (lhs.IsDouble() && rhs.IsDouble()) [[likely]] \ else if (lhs.IsDouble() && rhs.IsDouble()) [[likely]] \
{ \ { \
registers[a] = Value::FromDouble(lhs.AsDouble() op rhs.AsDouble()); \ currentFrame->registerBase[a] = Value::FromDouble(lhs.AsDouble() op rhs.AsDouble()); \
} \ } \
/* 隐式类型提升Int 与 Double 混合运算 */ \ /* 隐式类型提升Int 与 Double 混合运算 */ \
else if (lhs.IsInt() && rhs.IsDouble()) [[likely]] \ else if (lhs.IsInt() && rhs.IsDouble()) [[likely]] \
{ \ { \
registers[a] = Value::FromDouble(lhs.AsInt() op rhs.AsDouble()); \ currentFrame->registerBase[a] = Value::FromDouble(lhs.AsInt() op rhs.AsDouble()); \
} \ } \
else if (lhs.IsDouble() && rhs.IsInt()) [[likely]] \ else if (lhs.IsDouble() && rhs.IsInt()) [[likely]] \
{ \ { \
registers[a] = Value::FromDouble(lhs.AsDouble() op rhs.AsInt()); \ currentFrame->registerBase[a] = Value::FromDouble(lhs.AsDouble() op rhs.AsInt()); \
} \ } \
else \ else \
{ \ { \
assert(false && "VM Runtime Error: Unsupported types for arithmetic operation"); \ assert(false && "VM Runtime Error: Unsupported types for arithmetic operation"); \
} \ } \
break; \ break; \
} }
#define BINARY_COMPARE_OP(opCode, op) \ #define BINARY_COMPARE_OP(opCode, op) \
case OpCode::opCode: { \ case OpCode::opCode: { \
std::uint8_t b = decodeB(inst); \ std::uint8_t b = decodeB(inst); \
std::uint8_t c = decodeC(inst); \ std::uint8_t c = decodeC(inst); \
Value lhs = registers[b]; \ Value lhs = currentFrame->registerBase[b]; \
Value rhs = registers[c]; \ Value rhs = currentFrame->registerBase[c]; \
if (lhs.IsInt() && rhs.IsInt()) [[likely]] \ if (lhs.IsInt() && rhs.IsInt()) [[likely]] \
{ \ { \
registers[a] = (lhs.AsInt() op rhs.AsInt()) ? Value::GetTrueInstance() : Value::GetFalseInstance(); \ currentFrame->registerBase[a] = (lhs.AsInt() op rhs.AsInt()) ? \
} \ Value::GetTrueInstance() : \
else if (lhs.IsDouble() && rhs.IsDouble()) [[likely]] \ Value::GetFalseInstance(); \
{ \ } \
registers[a] = (lhs.AsDouble() op rhs.AsDouble()) ? Value::GetTrueInstance() : Value::GetFalseInstance(); \ else if (lhs.IsDouble() && rhs.IsDouble()) [[likely]] \
} \ { \
else if (lhs.IsInt() && rhs.IsDouble()) [[likely]] \ currentFrame->registerBase[a] = (lhs.AsDouble() op rhs.AsDouble()) ? \
{ \ Value::GetTrueInstance() : \
registers[a] = (lhs.AsInt() op rhs.AsDouble()) ? Value::GetTrueInstance() : Value::GetFalseInstance(); \ Value::GetFalseInstance(); \
} \ } \
else if (lhs.IsDouble() && rhs.IsInt()) [[likely]] \ else if (lhs.IsInt() && rhs.IsDouble()) [[likely]] \
{ \ { \
registers[a] = (lhs.AsDouble() op rhs.AsInt()) ? Value::GetTrueInstance() : Value::GetFalseInstance(); \ currentFrame->registerBase[a] = (lhs.AsInt() op rhs.AsDouble()) ? \
} \ Value::GetTrueInstance() : \
else \ Value::GetFalseInstance(); \
{ \ } \
/* TODO: 非数字比较 */ \ else if (lhs.IsDouble() && rhs.IsInt()) [[likely]] \
assert(false && "VM Runtime Error: Unsupported types for comparison"); \ { \
} \ currentFrame->registerBase[a] = (lhs.AsDouble() op rhs.AsInt()) ? \
break; \ Value::GetTrueInstance() : \
Value::GetFalseInstance(); \
} \
else \
{ \
/* TODO: 非数字比较 */ \
assert(false && "VM Runtime Error: Unsupported types for comparison"); \
} \
break; \
} }
namespace Fig namespace Fig
{ {
Result<Value, Error> VM::Execute(Proto *proto) Result<Value, Error> VM::Execute(CompiledModule *compiledModule)
{ {
// 指令指针 (Instruction Pointer / PC) 和 常量池指针 Proto *entry = compiledModule->protos[0];
const Instruction *ip = proto->code.data(); pushFrame(entry, registers);
const Value *k = proto->constants.data();
// 核心解释器循环 (The Dispatch Loop)
while (true) while (true)
{ {
// 取指并递增指针 // 取指并递增指针
Instruction inst = *ip++; Instruction inst = *(currentFrame->ip++);
// 解码 OpCode 和 A 操作数 // 解码 OpCode 和 A 操作数
OpCode op = decodeOpCode(inst); OpCode op = decodeOpCode(inst);
@@ -92,34 +98,63 @@ namespace Fig
case OpCode::LoadK: { case OpCode::LoadK: {
std::uint16_t bx = decodeBx(inst); std::uint16_t bx = decodeBx(inst);
registers[a] = k[bx]; // constants currentFrame->registerBase[a] = currentFrame->getConstant(bx); // constants
break;
}
case OpCode::LoadTrue: {
currentFrame->registerBase[a] = Value::GetTrueInstance();
break;
}
case OpCode::LoadFalse: {
currentFrame->registerBase[a] = Value::GetFalseInstance();
break;
}
case OpCode::LoadNull: {
currentFrame->registerBase[a] = Value::GetNullInstance();
break;
}
case OpCode::FastCall: {
Proto *proto = compiledModule->protos[a];
std::uint8_t baseReg = decodeB(inst);
pushFrame(proto, currentFrame->registerBase + baseReg);
break;
}
case OpCode::Call: {
break; break;
} }
case OpCode::Return: { case OpCode::Return: {
return registers[a]; *currentFrame->registerBase = currentFrame->registerBase[a];
popFrame();
break;
} }
case OpCode::Jmp: { case OpCode::Jmp: {
std::int16_t sbx = decodeSBx(inst); std::int16_t sbx = decodeSBx(inst);
ip += sbx; currentFrame->ip += sbx;
break; break;
} }
case OpCode::JmpIfFalse: { case OpCode::JmpIfFalse: {
Value &v = registers[a]; Value &v = currentFrame->registerBase[a];
bool cond = v.AsBool(); // 条件类型 Compiler检查 bool cond = v.AsBool(); // 条件类型 Compiler检查
if (!cond) if (!cond)
{ {
std::int16_t sbx = decodeSBx(inst); std::int16_t sbx = decodeSBx(inst);
ip += sbx; currentFrame->ip += sbx;
} }
break; break;
} }
case OpCode::Mov: { case OpCode::Mov: {
std::uint16_t bx = decodeBx(inst); std::uint16_t bx = decodeBx(inst);
registers[a] = registers[bx]; currentFrame->registerBase[a] = currentFrame->registerBase[bx];
break; break;
} }
@@ -136,9 +171,9 @@ namespace Fig
BINARY_COMPARE_OP(LessEqual, <=); BINARY_COMPARE_OP(LessEqual, <=);
default: { // default: {
assert(false && "VM: Unknown OpCode encountered!"); // assert(false && "VM: Unknown OpCode encountered!");
} // }
} }
} }
return Value::GetNullInstance(); return Value::GetNullInstance();

View File

@@ -16,6 +16,18 @@
namespace Fig namespace Fig
{ {
struct CallFrame
{
Proto *proto; // 当前执行的原型
Instruction *ip; // 当前指令指针
Value *registerBase; // 寄存器起点
inline Value getConstant(std::uint16_t idx)
{
return proto->constants[idx];
}
};
class VM class VM
{ {
private: private:
@@ -24,6 +36,8 @@ namespace Fig
// 一次性分配 // 一次性分配
Value registers[MAX_REGISTERS]; Value registers[MAX_REGISTERS];
DynArray<CallFrame> frames;
CallFrame *currentFrame;
public: public:
VM() VM()
{ {
@@ -34,6 +48,26 @@ namespace Fig
} }
private: private:
void pushFrame(Proto *proto, Value *base)
{
frames.push_back({
proto,
proto->code.data(),
base
});
currentFrame = &frames.back();
}
void popFrame()
{
frames.pop_back();
if (!frames.empty())
{
currentFrame = &frames.back();
}
}
inline OpCode decodeOpCode(Instruction inst) inline OpCode decodeOpCode(Instruction inst)
{ {
return static_cast<OpCode>(inst & 0xFF); return static_cast<OpCode>(inst & 0xFF);
@@ -61,7 +95,7 @@ namespace Fig
public: public:
// 执行入口:接收 Proto // 执行入口:接收 Proto
Result<Value, Error> Execute(Proto *proto); Result<Value, Error> Execute(CompiledModule *);
inline void PrintRegisters() inline void PrintRegisters()
{ {

View File

@@ -50,34 +50,53 @@ int main()
std::cout << "analyzer: Program OK, PASSED\n"; std::cout << "analyzer: Program OK, PASSED\n";
Compiler compiler(fileName, manager); Compiler compiler(fileName, manager);
const auto &proto_result = compiler.Compile(program); const auto &comp_result = compiler.Compile(program);
if (!proto_result) if (!comp_result)
{ {
ReportError(proto_result.error(), manager); ReportError(comp_result.error(), manager);
return 1; return 1;
} }
Proto *proto = *proto_result; CompiledModule *compiledModule = *comp_result;
std::cout << "=== Constant Pool ===" << '\n'; size_t cnt = 0;
for (size_t i = 0; i < proto->constants.size(); ++i) for (Proto *proto : compiledModule->protos)
{ {
std::print("[{}] {}\n", i, proto->constants[i].ToString()); std::cout << "\n"
<< "Proto: " << cnt++ << '\n';
std::cout << " Constant Pool" << '\n';
for (size_t i = 0; i < proto->constants.size(); ++i)
{
std::print("[{}] {}\n", i, proto->constants[i].ToString());
}
DumpCode(proto->code);
std::cout << "\nMax Stack Size: " << (int) proto->maxStack << std::endl;
} }
DumpCode(proto->code);
std::cout << "\nMax Stack Size: " << (int) proto->maxStack << std::endl;
VM vm; VM vm;
auto result_ = vm.Execute(proto); using Clock = std::chrono::high_resolution_clock;
Clock clock;
auto start = clock.now();
auto result_ = vm.Execute(compiledModule);
auto end = clock.now();
auto duration = end - start;
if (!result_) if (!result_)
{ {
ReportError(result_.error(), manager); ReportError(result_.error(), manager);
return 1; return 1;
} }
Value result = *result_; Value result = *result_;
std::cout << "result: " << result.ToString() << "\n"; std::cout << "result: " << result.ToString() << "\n";
std::cout << "execution cost: " << std::chrono::duration_cast<std::chrono::milliseconds>(duration).count() << "ms\n";
vm.PrintRegisters(); vm.PrintRegisters();
} }