Skip to content

compiler: refactor globals and named types #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

// emitMakeChan returns a new channel value for the given channel type.
func (c *Compiler) emitMakeChan(expr *ssa.MakeChan) (llvm.Value, error) {
chanType := c.mod.GetTypeByName("runtime.channel")
chanType := c.getLLVMType(c.getRuntimeType("channel"))
size := c.targetData.TypeAllocSize(chanType)
sizeValue := llvm.ConstInt(c.uintptrType, size, false)
ptr := c.createRuntimeCall("alloc", []llvm.Value{sizeValue}, "chan.alloc")
Expand Down
91 changes: 42 additions & 49 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package compiler
import (
"errors"
"fmt"
"go/ast"
"go/build"
"go/constant"
"go/token"
Expand Down Expand Up @@ -66,6 +67,7 @@ type Compiler struct {
interfaceInvokeWrappers []interfaceInvokeWrapper
ir *ir.Program
diagnostics []error
astComments map[string]*ast.CommentGroup
}

type Frame struct {
Expand Down Expand Up @@ -275,39 +277,7 @@ func (c *Compiler) Compile(mainPath string) []error {

var frames []*Frame

// Declare all named struct types.
for _, t := range c.ir.NamedTypes {
if named, ok := t.Type.Type().(*types.Named); ok {
if _, ok := named.Underlying().(*types.Struct); ok {
t.LLVMType = c.ctx.StructCreateNamed(named.Obj().Pkg().Path() + "." + named.Obj().Name())
}
}
}

// Define all named struct types.
for _, t := range c.ir.NamedTypes {
if named, ok := t.Type.Type().(*types.Named); ok {
if st, ok := named.Underlying().(*types.Struct); ok {
llvmType := c.getLLVMType(st)
t.LLVMType.StructSetBody(llvmType.StructElementTypes(), false)
}
}
}

// Declare all globals.
for _, g := range c.ir.Globals {
typ := g.Type().(*types.Pointer).Elem()
llvmType := c.getLLVMType(typ)
global := c.mod.NamedGlobal(g.LinkName())
if global.IsNil() {
global = llvm.AddGlobal(c.mod, llvmType, g.LinkName())
}
g.LLVMGlobal = global
if !g.IsExtern() {
global.SetLinkage(llvm.InternalLinkage)
global.SetInitializer(c.getZeroValue(llvmType))
}
}
c.loadASTComments(lprogram)

// Declare all functions.
for _, f := range c.ir.Functions {
Expand Down Expand Up @@ -413,6 +383,22 @@ func (c *Compiler) Compile(mainPath string) []error {
return c.diagnostics
}

// getRuntimeType obtains a named type from the runtime package and returns it
// as a Go type.
func (c *Compiler) getRuntimeType(name string) types.Type {
return c.ir.Program.ImportedPackage("runtime").Type(name).Type()
}

// getLLVMRuntimeType obtains a named type from the runtime package and returns
// it as a LLVM type, creating it if necessary. It is a shorthand for
// getLLVMType(getRuntimeType(name)).
func (c *Compiler) getLLVMRuntimeType(name string) llvm.Type {
return c.getLLVMType(c.getRuntimeType(name))
}

// getLLVMType creates and returns a LLVM type for a Go type. In the case of
// named struct types (or Go types implemented as named LLVM structs such as
// strings) it also creates it first if necessary.
func (c *Compiler) getLLVMType(goType types.Type) llvm.Type {
switch typ := goType.(type) {
case *types.Array:
Expand Down Expand Up @@ -441,7 +427,7 @@ func (c *Compiler) getLLVMType(goType types.Type) llvm.Type {
case types.Complex128:
return c.ctx.StructType([]llvm.Type{c.ctx.DoubleType(), c.ctx.DoubleType()}, false)
case types.String, types.UntypedString:
return c.mod.GetTypeByName("runtime._string")
return c.getLLVMRuntimeType("_string")
case types.Uintptr:
return c.uintptrType
case types.UnsafePointer:
Expand All @@ -450,16 +436,23 @@ func (c *Compiler) getLLVMType(goType types.Type) llvm.Type {
panic("unknown basic type: " + typ.String())
}
case *types.Chan:
return llvm.PointerType(c.mod.GetTypeByName("runtime.channel"), 0)
return llvm.PointerType(c.getLLVMRuntimeType("channel"), 0)
case *types.Interface:
return c.mod.GetTypeByName("runtime._interface")
return c.getLLVMRuntimeType("_interface")
case *types.Map:
return llvm.PointerType(c.mod.GetTypeByName("runtime.hashmap"), 0)
return llvm.PointerType(c.getLLVMRuntimeType("hashmap"), 0)
case *types.Named:
if _, ok := typ.Underlying().(*types.Struct); ok {
llvmType := c.mod.GetTypeByName(typ.Obj().Pkg().Path() + "." + typ.Obj().Name())
if st, ok := typ.Underlying().(*types.Struct); ok {
// Structs are a special case. While other named types are ignored
// in LLVM IR, named structs are implemented as named structs in
// LLVM. This is because it is otherwise impossible to create
// self-referencing types such as linked lists.
llvmName := typ.Obj().Pkg().Path() + "." + typ.Obj().Name()
llvmType := c.mod.GetTypeByName(llvmName)
if llvmType.IsNil() {
panic("underlying type not found: " + typ.Obj().Pkg().Path() + "." + typ.Obj().Name())
llvmType = c.ctx.StructCreateNamed(llvmName)
underlying := c.getLLVMType(st)
llvmType.StructSetBody(underlying.StructElementTypes(), false)
}
return llvmType
}
Expand Down Expand Up @@ -1356,9 +1349,9 @@ func (c *Compiler) getValue(frame *Frame, expr ssa.Value) llvm.Value {
}
return c.createFuncValue(fn.LLVMFn, llvm.Undef(c.i8ptrType), fn.Signature)
case *ssa.Global:
value := c.ir.GetGlobal(expr).LLVMGlobal
value := c.getGlobal(expr)
if value.IsNil() {
c.addError(expr.Pos(), "global not found: "+c.ir.GetGlobal(expr).LinkName())
c.addError(expr.Pos(), "global not found: "+expr.RelString(nil))
return llvm.Undef(c.getLLVMType(expr.Type()))
}
return value
Expand Down Expand Up @@ -1696,9 +1689,9 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
var iteratorType llvm.Type
switch typ := expr.X.Type().Underlying().(type) {
case *types.Basic: // string
iteratorType = c.mod.GetTypeByName("runtime.stringIterator")
iteratorType = c.getLLVMRuntimeType("stringIterator")
case *types.Map:
iteratorType = c.mod.GetTypeByName("runtime.hashmapIterator")
iteratorType = c.getLLVMRuntimeType("hashmapIterator")
default:
panic("unknown type in range: " + typ.String())
}
Expand Down Expand Up @@ -1858,7 +1851,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {

newPtr := c.builder.CreateInBoundsGEP(oldPtr, []llvm.Value{low}, "")
newLen := c.builder.CreateSub(high, low, "")
str := llvm.Undef(c.mod.GetTypeByName("runtime._string"))
str := llvm.Undef(c.getLLVMRuntimeType("_string"))
str = c.builder.CreateInsertValue(str, newPtr, 0, "")
str = c.builder.CreateInsertValue(str, newLen, 1, "")
return str, nil
Expand Down Expand Up @@ -2237,7 +2230,7 @@ func (c *Compiler) parseConst(prefix string, expr *ssa.Const) llvm.Value {
global.SetUnnamedAddr(true)
zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false)
strPtr := c.builder.CreateInBoundsGEP(global, []llvm.Value{zero, zero}, "")
strObj := llvm.ConstNamedStruct(c.mod.GetTypeByName("runtime._string"), []llvm.Value{strPtr, strLen})
strObj := llvm.ConstNamedStruct(c.getLLVMRuntimeType("_string"), []llvm.Value{strPtr, strLen})
return strObj
} else if typ.Kind() == types.UnsafePointer {
if !expr.IsNil() {
Expand Down Expand Up @@ -2290,7 +2283,7 @@ func (c *Compiler) parseConst(prefix string, expr *ssa.Const) llvm.Value {
llvm.ConstInt(c.uintptrType, 0, false),
llvm.ConstPointerNull(c.i8ptrType),
}
return llvm.ConstNamedStruct(c.mod.GetTypeByName("runtime._interface"), fields)
return llvm.ConstNamedStruct(c.getLLVMRuntimeType("_interface"), fields)
case *types.Pointer:
if expr.Value != nil {
panic("expected nil pointer constant")
Expand Down Expand Up @@ -2510,8 +2503,8 @@ func (c *Compiler) parseUnOp(frame *Frame, unop *ssa.UnOp) (llvm.Value, error) {
// var C.add unsafe.Pointer
// Instead of a load from the global, create a bitcast of the
// function pointer itself.
global := c.ir.GetGlobal(unop.X.(*ssa.Global))
name := global.LinkName()[:len(global.LinkName())-len("$funcaddr")]
globalName := c.getGlobalInfo(unop.X.(*ssa.Global)).linkName
name := globalName[:len(globalName)-len("$funcaddr")]
fn := c.mod.NamedFunction(name)
if fn.IsNil() {
return llvm.Value{}, c.makeError(unop.Pos(), "cgo function not found: "+name)
Expand Down
8 changes: 4 additions & 4 deletions compiler/defer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (c *Compiler) deferInitFunc(frame *Frame) {
frame.deferClosureFuncs = make(map[*ir.Function]int)

// Create defer list pointer.
deferType := llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)
deferType := llvm.PointerType(c.getLLVMRuntimeType("_defer"), 0)
frame.deferPtr = c.builder.CreateAlloca(deferType, "deferPtr")
c.builder.CreateStore(llvm.ConstPointerNull(deferType), frame.deferPtr)
}
Expand Down Expand Up @@ -200,7 +200,7 @@ func (c *Compiler) emitRunDefers(frame *Frame) {
}

// Get the real defer struct type and cast to it.
valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0), c.i8ptrType}
valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.getLLVMRuntimeType("_defer"), 0), c.i8ptrType}
for _, arg := range callback.Args {
valueTypes = append(valueTypes, c.getLLVMType(arg.Type()))
}
Expand Down Expand Up @@ -231,7 +231,7 @@ func (c *Compiler) emitRunDefers(frame *Frame) {
// Direct call.

// Get the real defer struct type and cast to it.
valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)}
valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.getLLVMRuntimeType("_defer"), 0)}
for _, param := range callback.Params {
valueTypes = append(valueTypes, c.getLLVMType(param.Type()))
}
Expand Down Expand Up @@ -260,7 +260,7 @@ func (c *Compiler) emitRunDefers(frame *Frame) {
case *ssa.MakeClosure:
// Get the real defer struct type and cast to it.
fn := c.ir.GetFunction(callback.Fn.(*ssa.Function))
valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)}
valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.getLLVMRuntimeType("_defer"), 0)}
params := fn.Signature.Params()
for i := 0; i < params.Len(); i++ {
valueTypes = append(valueTypes, c.getLLVMType(params.At(i).Type()))
Expand Down
2 changes: 1 addition & 1 deletion compiler/func-lowering.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (c *Compiler) LowerFuncValues() {
}

// Find all func values used in the program with their signatures.
funcValueWithSignaturePtr := llvm.PointerType(c.mod.GetTypeByName("runtime.funcValueWithSignature"), 0)
funcValueWithSignaturePtr := llvm.PointerType(c.getLLVMRuntimeType("funcValueWithSignature"), 0)
signatures := map[string]*funcSignatureInfo{}
for global := c.mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) {
if global.Type() != funcValueWithSignaturePtr {
Expand Down
4 changes: 2 additions & 2 deletions compiler/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (c *Compiler) createFuncValue(funcPtr, context llvm.Value, sig *types.Signa
funcValueWithSignatureGlobalName := funcPtr.Name() + "$withSignature"
funcValueWithSignatureGlobal := c.mod.NamedGlobal(funcValueWithSignatureGlobalName)
if funcValueWithSignatureGlobal.IsNil() {
funcValueWithSignatureType := c.mod.GetTypeByName("runtime.funcValueWithSignature")
funcValueWithSignatureType := c.getLLVMRuntimeType("funcValueWithSignature")
funcValueWithSignature := llvm.ConstNamedStruct(funcValueWithSignatureType, []llvm.Value{
llvm.ConstPtrToInt(funcPtr, c.uintptrType),
sigGlobal,
Expand Down Expand Up @@ -126,7 +126,7 @@ func (c *Compiler) getFuncType(typ *types.Signature) llvm.Type {
rawPtr := c.getRawFuncType(typ)
return c.ctx.StructType([]llvm.Type{c.i8ptrType, rawPtr}, false)
case funcValueSwitch:
return c.mod.GetTypeByName("runtime.funcValue")
return c.getLLVMRuntimeType("funcValue")
default:
panic("unimplemented func value variant")
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/goroutine-lowering.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (c *Compiler) markAsyncFunctions() (needsScheduler bool, err error) {

// Coroutine setup.
c.builder.SetInsertPointBefore(f.EntryBasicBlock().FirstInstruction())
taskState := c.builder.CreateAlloca(c.mod.GetTypeByName("runtime.taskState"), "task.state")
taskState := c.builder.CreateAlloca(c.getLLVMRuntimeType("taskState"), "task.state")
stateI8 := c.builder.CreateBitCast(taskState, c.i8ptrType, "task.state.i8")
id := c.builder.CreateCall(coroIdFunc, []llvm.Value{
llvm.ConstInt(c.ctx.Int32Type(), 0, false),
Expand Down
4 changes: 2 additions & 2 deletions compiler/interface-lowering.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ func (c *Compiler) LowerInterfaces() {
// run runs the pass itself.
func (p *lowerInterfacesPass) run() {
// Collect all type codes.
typecodeIDPtr := llvm.PointerType(p.mod.GetTypeByName("runtime.typecodeID"), 0)
typeInInterfacePtr := llvm.PointerType(p.mod.GetTypeByName("runtime.typeInInterface"), 0)
typecodeIDPtr := llvm.PointerType(p.getLLVMRuntimeType("typecodeID"), 0)
typeInInterfacePtr := llvm.PointerType(p.getLLVMRuntimeType("typeInInterface"), 0)
var typesInInterfaces []llvm.Value
for global := p.mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) {
switch global.Type() {
Expand Down
10 changes: 5 additions & 5 deletions compiler/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ func (c *Compiler) parseMakeInterface(val llvm.Value, typ types.Type, pos token.
itfMethodSetGlobal := c.getTypeMethodSet(typ)
itfConcreteTypeGlobal := c.mod.NamedGlobal("typeInInterface:" + itfTypeCodeGlobal.Name())
if itfConcreteTypeGlobal.IsNil() {
typeInInterface := c.mod.GetTypeByName("runtime.typeInInterface")
typeInInterface := c.getLLVMRuntimeType("typeInInterface")
itfConcreteTypeGlobal = llvm.AddGlobal(c.mod, typeInInterface, "typeInInterface:"+itfTypeCodeGlobal.Name())
itfConcreteTypeGlobal.SetInitializer(llvm.ConstNamedStruct(typeInInterface, []llvm.Value{itfTypeCodeGlobal, itfMethodSetGlobal}))
itfConcreteTypeGlobal.SetGlobalConstant(true)
itfConcreteTypeGlobal.SetLinkage(llvm.PrivateLinkage)
}
itfTypeCode := c.builder.CreatePtrToInt(itfConcreteTypeGlobal, c.uintptrType, "")
itf := llvm.Undef(c.mod.GetTypeByName("runtime._interface"))
itf := llvm.Undef(c.getLLVMRuntimeType("_interface"))
itf = c.builder.CreateInsertValue(itf, itfTypeCode, 0, "")
itf = c.builder.CreateInsertValue(itf, itfValue, 1, "")
return itf
Expand All @@ -48,7 +48,7 @@ func (c *Compiler) getTypeCode(typ types.Type) llvm.Value {
globalName := "type:" + getTypeCodeName(typ)
global := c.mod.NamedGlobal(globalName)
if global.IsNil() {
global = llvm.AddGlobal(c.mod, c.mod.GetTypeByName("runtime.typecodeID"), globalName)
global = llvm.AddGlobal(c.mod, c.getLLVMRuntimeType("typecodeID"), globalName)
global.SetGlobalConstant(true)
}
return global
Expand Down Expand Up @@ -163,11 +163,11 @@ func (c *Compiler) getTypeMethodSet(typ types.Type) llvm.Value {
ms := c.ir.Program.MethodSets.MethodSet(typ)
if ms.Len() == 0 {
// no methods, so can leave that one out
return llvm.ConstPointerNull(llvm.PointerType(c.mod.GetTypeByName("runtime.interfaceMethodInfo"), 0))
return llvm.ConstPointerNull(llvm.PointerType(c.getLLVMRuntimeType("interfaceMethodInfo"), 0))
}

methods := make([]llvm.Value, ms.Len())
interfaceMethodInfoType := c.mod.GetTypeByName("runtime.interfaceMethodInfo")
interfaceMethodInfoType := c.getLLVMRuntimeType("interfaceMethodInfo")
for i := 0; i < ms.Len(); i++ {
method := ms.At(i)
signatureGlobal := c.getMethodSignature(method.Obj().(*types.Func))
Expand Down
Loading