Skip to content

Commit

Permalink
generate: iterate on function support
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc committed Sep 3, 2023
1 parent 5a51446 commit 9df8d62
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 32 deletions.
19 changes: 18 additions & 1 deletion generate/codegen/gen_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (f *Function) GoReturn(currentModule *modules.Module) string {
if f.ReturnType == nil {
return ""
}
log.Printf("rendering GoReturn function return: %s %T", f.ReturnType, f.ReturnType)
return f.ReturnType.GoName(currentModule, true)
}

Expand Down Expand Up @@ -127,10 +128,26 @@ func (f *Function) WriteGoCallCode(currentModule *modules.Module, cw *CodeWriter
cw.WriteLine("}")
}

func (f *Function) WriteObjcWrapper(currentModule *modules.Module, cw *CodeWriter) {
if f.Deprecated {
return
cw.WriteLine("// deprecated")
}
returnTypeStr := f.Type.ReturnType.CName()
cw.WriteLineF("%v %v(%v) {", returnTypeStr, f.GoName, f.CArgs(currentModule))
cw.Indent()
var args []string
for _, p := range f.Parameters {
args = append(args, p.Name)
}
cw.WriteLineF("return %v(%v);", f.Type.Name, strings.Join(args, ", "))
cw.UnIndent()
cw.WriteLine("}")
}

func (f *Function) WriteCSignature(currentModule *modules.Module, cw *CodeWriter) {
var returnTypeStr string
rt := f.Type.ReturnType
log.Printf("rt: %T", rt)
returnTypeStr = rt.CName()
cw.WriteLineF("// %v %v(%v); ", returnTypeStr, f.GoName, f.CArgs(currentModule))
}
Expand Down
2 changes: 1 addition & 1 deletion generate/codegen/gen_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

// Struct is code generator for objective-c struct
type Struct struct {
Type *typing.StructType
Type typing.Type
Name string // the first part of objc function name
GoName string
Deprecated bool // if has been deprecated
Expand Down
46 changes: 37 additions & 9 deletions generate/codegen/modulewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func (m *ModuleWriter) WriteCode() {
m.WriteTypeAliases()
m.WriteStructs()
m.WriteFunctions()
m.WriteFunctionWrappers()
if m.Module.Package == "coreimage" {
// filter protocols maybe arent "real" protocols?
// get "cannot find protocol declaration" with protocol imports
Expand Down Expand Up @@ -125,19 +126,20 @@ func (m *ModuleWriter) WriteStructs() {
cw.WriteLine(")")

for _, s := range m.Structs {
if s.DocURL != "" {
cw.WriteLine(fmt.Sprintf("// %s [Full Topic]", s.Description))
cw.WriteLine(fmt.Sprintf("//\n// [Full Topic]: %s", s.DocURL))
}

// if Ref type, allias to unsafe.Pointer
if strings.HasSuffix(s.Name, "Ref") {
if s.DocURL != "" {
cw.WriteLine(fmt.Sprintf("// %s [Full Topic]", s.Description))
cw.WriteLine(fmt.Sprintf("//\n// [Full Topic]: %s", s.DocURL))
}

cw.WriteLineF("type %s unsafe.Pointer", s.GoName)
continue
}
}
}

// WriteFunctions writes the go code to call exposed functions.
func (m *ModuleWriter) WriteFunctions() {
if len(m.Functions) == 0 {
return
Expand All @@ -156,10 +158,10 @@ func (m *ModuleWriter) WriteFunctions() {
cw.WriteLine("package " + m.Module.Package)

//TODO: determine imports from functions
cw.WriteLine(`// #import <stdlib.h>
// #import <stdint.h>
// #import <stdbool.h>
// #import <CoreGraphics/CGGeometry.h>`)
cw.WriteLineF(`// #import <stdlib.h>
// #import <stdint.h>
// #import <stdbool.h>
// #import "%s"`, m.Module.Header)
for _, f := range m.Functions {
f.WriteCSignature(&m.Module, cw)
}
Expand All @@ -185,6 +187,32 @@ func (m *ModuleWriter) WriteFunctions() {
}
}

// WriteFunctionWrappers writes the objc code to wrap exposed functions.
// The cgo type system is unaware of objective c types so these wrappers must exist to allow
// us to call the functions and return appropritely.
func (m *ModuleWriter) WriteFunctionWrappers() {
if len(m.Functions) == 0 {
return
}

filePath := filepath.Join(m.PlatformDir, m.Module.Package, "functions.gen.m")
os.MkdirAll(filepath.Dir(filePath), 0755)
f, err := os.Create(filePath)
if err != nil {
panic(err)
}
defer f.Close()

cw := &CodeWriter{Writer: f, IndentStr: "\t"}
cw.WriteLine(AutoGeneratedMark)

//TODO: determine appropriate imports
cw.WriteLineF("#import \"%s\"", m.Module.Header)
for _, f := range m.Functions {
f.WriteObjcWrapper(&m.Module, cw)
}
}

func (m *ModuleWriter) WriteEnumAliases() {
enums := make([]*AliasInfo, len(m.EnumAliases))
copy(enums, m.EnumAliases)
Expand Down
3 changes: 2 additions & 1 deletion generate/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"log"

"github.com/progrium/macdriver/generate/codegen"
"github.com/progrium/macdriver/generate/modules"
"github.com/progrium/macdriver/generate/typing"
)

Expand Down Expand Up @@ -253,7 +254,7 @@ func (db *Generator) ToFunction(fw string, sym Symbol) *codegen.Function {
}
fn := &codegen.Function{
Name: sym.Name,
GoName: sym.Name,
GoName: modules.TrimPrefix(sym.Name),
Description: sym.Description,
DocURL: sym.DocURL(),
Type: fntyp,
Expand Down
6 changes: 3 additions & 3 deletions generate/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ func (db *Generator) Generate(platform string, version int, rootDir string, fram
}
mw.Functions = append(mw.Functions, fn)
case "Struct":
fn := db.ToStruct(framework, s)
if fn == nil {
s := db.ToStruct(framework, s)
if s == nil {
continue
}
mw.Structs = append(mw.Structs, fn)
mw.Structs = append(mw.Structs, s)
}
}
mw.WriteCode()
Expand Down
7 changes: 1 addition & 6 deletions generate/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"github.com/progrium/macdriver/generate/codegen"
"github.com/progrium/macdriver/generate/modules"
"github.com/progrium/macdriver/generate/typing"
)

func (db *Generator) ToStruct(fw string, sym Symbol) *codegen.Struct {
Expand All @@ -16,16 +15,12 @@ func (db *Generator) ToStruct(fw string, sym Symbol) *codegen.Struct {
return nil
}
typ := db.TypeFromSymbol(sym)
styp, ok := typ.(*typing.StructType)
if !ok {
return nil
}
s := &codegen.Struct{
Name: sym.Name,
GoName: modules.TrimPrefix(sym.Name),
Description: sym.Description,
DocURL: sym.DocURL(),
Type: styp,
Type: typ,
}

return s
Expand Down
16 changes: 11 additions & 5 deletions generate/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ func (db *Generator) TypeFromSymbol(sym Symbol) typing.Type {
}
case "Union":
return &typing.RefType{
Name: sym.Name,
Name: sym.Name,
GName: modules.TrimPrefix(sym.Name),
}
case "Type":
if sym.Type != "Type Alias" {
Expand All @@ -66,7 +67,8 @@ func (db *Generator) TypeFromSymbol(sym Symbol) typing.Type {
// sym.Name == "NSZone" ||
sym.Name == "MusicSequence" {
return &typing.RefType{
Name: sym.Name,
Name: sym.Name,
GName: modules.TrimPrefix(sym.Name),
}
}
st, err := sym.Parse()
Expand All @@ -76,7 +78,8 @@ func (db *Generator) TypeFromSymbol(sym Symbol) typing.Type {
}
if st.Struct != nil {
return &typing.RefType{
Name: st.Struct.Name,
Name: st.Struct.Name,
GName: modules.TrimPrefix(sym.Name),
}
}
if st.TypeAlias == nil {
Expand All @@ -97,7 +100,9 @@ func (db *Generator) TypeFromSymbol(sym Symbol) typing.Type {
case "Struct":
if strings.HasSuffix(sym.Name, "Ref") {
return &typing.RefType{
Name: sym.Name,
Name: sym.Name,
GName: modules.TrimPrefix(sym.Name),
Module: modules.Get(module),
}
}
return &typing.StructType{
Expand All @@ -106,7 +111,8 @@ func (db *Generator) TypeFromSymbol(sym Symbol) typing.Type {
Module: modules.Get(module),
}
case "Function":
if sym.Name != "CGDisplayCreateImage" {
if sym.Name != "CGDisplayCreateImage" &&
sym.Name != "CGMainDisplayID" {
return nil
}
typ, err := sym.Parse()
Expand Down
18 changes: 12 additions & 6 deletions generate/typing/ref_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@ import (

// for weird struct refs like those ending in "Ref"
type RefType struct {
Name string // c and objc type name
// GName string // the go struct name
// Module *modules.Module // the module
Name string // c and objc type name
GName string // the go struct name
Module *modules.Module // the module
}

func (s *RefType) GoImports() set.Set[string] {
return set.New("unsafe")
if s.Module == nil {
return set.New("unsafe")
}
return set.New("github.com/progrium/macdriver/macos/" + s.Module.Package)
}

func (s *RefType) GoName(currentModule *modules.Module, receiveFromObjc bool) string {
return "unsafe.Pointer"
if s.Module == nil {
return "unsafe.Pointer"
}
return FullGoName(*s.Module, s.GName, *currentModule)
}

func (s *RefType) ObjcName() string {
Expand All @@ -30,5 +36,5 @@ func (s *RefType) CName() string {
}

func (s *RefType) DeclareModule() *modules.Module {
return nil
return s.Module
}

0 comments on commit 9df8d62

Please sign in to comment.