Skip to content

Commit

Permalink
upload: add support for reading data from io.Reader
Browse files Browse the repository at this point in the history
Fixes #94
  • Loading branch information
tulir committed Aug 16, 2024
1 parent 82a2975 commit 82d17aa
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 4 deletions.
66 changes: 62 additions & 4 deletions upload.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
Expand All @@ -14,8 +14,10 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"

"go.mau.fi/util/random"

Expand Down Expand Up @@ -87,7 +89,43 @@ func (cli *Client) Upload(ctx context.Context, plaintext []byte, appInfo MediaTy
dataHash := sha256.Sum256(dataToUpload)
resp.FileEncSHA256 = dataHash[:]

err = cli.rawUpload(ctx, dataToUpload, resp.FileEncSHA256, appInfo, false, &resp)
err = cli.rawUpload(ctx, bytes.NewReader(dataToUpload), resp.FileEncSHA256, appInfo, false, &resp)
return
}

// UploadReader uploads the given attachment to WhatsApp servers.
//
// This is otherwise identical to [Upload], but it reads the plaintext from an [io.Reader] instead of a byte slice.
// A temporary file is required for the encryption process. If tempFile is nil, a temporary file will be created
// and deleted after the upload.
func (cli *Client) UploadReader(ctx context.Context, plaintext io.Reader, tempFile io.ReadWriteSeeker, appInfo MediaType) (resp UploadResponse, err error) {
resp.MediaKey = random.Bytes(32)
iv, cipherKey, macKey, _ := getMediaKeys(resp.MediaKey, appInfo)
if tempFile == nil {
tempFile, err = os.CreateTemp("", "whatsmeow-upload-*")
if err != nil {
err = fmt.Errorf("failed to create temporary file: %w", err)
return
}
fmt.Println("OPENED TEMPFILE", tempFile.(*os.File).Name())
defer func() {
tempFileFile := tempFile.(*os.File)
_ = tempFileFile.Close()
_ = os.Remove(tempFileFile.Name())
fmt.Println("REMOVED TEMPFILE", tempFile.(*os.File).Name())
}()
}
resp.FileSHA256, resp.FileEncSHA256, resp.FileLength, err = cbcutil.EncryptStream(cipherKey, iv, macKey, plaintext, tempFile)
if err != nil {
err = fmt.Errorf("failed to encrypt file: %w", err)
return
}
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
err = fmt.Errorf("failed to seek to start of temporary file: %w", err)
return
}
err = cli.rawUpload(ctx, tempFile, resp.FileEncSHA256, appInfo, false, &resp)
return
}

Expand Down Expand Up @@ -125,11 +163,31 @@ func (cli *Client) UploadNewsletter(ctx context.Context, data []byte, appInfo Me
resp.FileLength = uint64(len(data))
hash := sha256.Sum256(data)
resp.FileSHA256 = hash[:]
err = cli.rawUpload(ctx, bytes.NewReader(data), resp.FileSHA256, appInfo, true, &resp)
return
}

// UploadNewsletterReader uploads the given attachment to WhatsApp servers without encrypting it first.
//
// This is otherwise identical to [Upload], but it reads the plaintext from an [io.Reader] instead of a byte slice.
// Unlike [UploadReader], this does not require a temporary file. However, the data needs to be hashed first,
// so an [io.ReadSeeker] is required to be able to read the data twice.
func (cli *Client) UploadNewsletterReader(ctx context.Context, data io.ReadSeeker, appInfo MediaType) (resp UploadResponse, err error) {
hasher := sha256.New()
var fileLength int64
fileLength, err = io.Copy(hasher, data)
resp.FileLength = uint64(fileLength)
resp.FileSHA256 = hasher.Sum(nil)
_, err = data.Seek(0, io.SeekStart)
if err != nil {
err = fmt.Errorf("failed to seek to start of data: %w", err)
return
}
err = cli.rawUpload(ctx, data, resp.FileSHA256, appInfo, true, &resp)
return
}

func (cli *Client) rawUpload(ctx context.Context, dataToUpload, fileHash []byte, appInfo MediaType, newsletter bool, resp *UploadResponse) error {
func (cli *Client) rawUpload(ctx context.Context, dataToUpload io.Reader, fileHash []byte, appInfo MediaType, newsletter bool, resp *UploadResponse) error {
mediaConn, err := cli.refreshMediaConn(false)
if err != nil {
return fmt.Errorf("failed to refresh media connections: %w", err)
Expand Down Expand Up @@ -168,7 +226,7 @@ func (cli *Client) rawUpload(ctx context.Context, dataToUpload, fileHash []byte,
RawQuery: q.Encode(),
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadURL.String(), bytes.NewReader(dataToUpload))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadURL.String(), dataToUpload)
if err != nil {
return fmt.Errorf("failed to prepare request: %w", err)
}
Expand Down
47 changes: 47 additions & 0 deletions util/cbcutil/cbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"io"
)
Expand Down Expand Up @@ -99,3 +102,47 @@ func unpad(src []byte) ([]byte, error) {

return src[:(length - padLen)], nil
}

func EncryptStream(key, iv, macKey []byte, plaintext io.Reader, ciphertext io.Writer) ([]byte, []byte, uint64, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to create cipher: %w", err)
}
cbc := cipher.NewCBCEncrypter(block, iv)

plainHasher := sha256.New()
cipherHasher := sha256.New()
cipherMAC := hmac.New(sha256.New, macKey)
cipherMAC.Write(iv)

buf := make([]byte, 32*1024)
var size int
hasMore := true
for hasMore {
var n int
n, err = io.ReadFull(plaintext, buf)
plainHasher.Write(buf[:n])
size += n
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
padding := aes.BlockSize - size%aes.BlockSize
buf = append(buf[:n], bytes.Repeat([]byte{byte(padding)}, padding)...)
hasMore = false
} else if err != nil {
return nil, nil, 0, fmt.Errorf("failed to read file: %w", err)
}
cbc.CryptBlocks(buf, buf)
cipherMAC.Write(buf)
cipherHasher.Write(buf)
_, err = ciphertext.Write(buf)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to write file: %w", err)
}
}
mac := cipherMAC.Sum(nil)[:10]
cipherHasher.Write(mac)
_, err = ciphertext.Write(mac)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to write checksum to file: %w", err)
}
return plainHasher.Sum(nil), cipherHasher.Sum(nil), uint64(size), nil
}

0 comments on commit 82d17aa

Please sign in to comment.