diff --git a/src/xmpp/client.go b/src/xmpp/client.go index 3d0c846..59264e7 100644 --- a/src/xmpp/client.go +++ b/src/xmpp/client.go @@ -75,11 +75,17 @@ func NewClientXMPP(jid JID, password string, config *ClientConfig) (*XMPP, error func startClient(stream *Stream, jid JID) error { - s := fmt.Sprintf( - "", - jid, - jid.Domain) - if err := stream.Send(s); err != nil { + start := xml.StartElement{ + xml.Name{"stream", "stream"}, + []xml.Attr{ + xml.Attr{xml.Name{"", "xmlns"}, "jabber:client"}, + xml.Attr{xml.Name{"xmlns", "stream"}, "http://etherx.jabber.org/streams"}, + xml.Attr{xml.Name{"", "from"}, jid.Full()}, + xml.Attr{xml.Name{"", "to"}, jid.Domain}, + }, + } + + if err := stream.SendStart(&start); err != nil { return err } @@ -103,10 +109,8 @@ func authenticate(stream *Stream, mechanisms []string, user, password string) er func authenticatePlain(stream *Stream, user, password string) error { - x := fmt.Sprintf( - "%s", - saslEncodePlain(user, password)) - if err := stream.Send(x); err != nil { + auth := saslAuth{Mechanism: "PLAIN", Message: saslEncodePlain(user, password)} + if err := stream.Send(&auth); err != nil { return err } @@ -123,6 +127,12 @@ func authenticatePlain(stream *Stream, user, password string) error { return nil } +type saslAuth struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl auth"` + Mechanism string `xml:"mechanism,attr"` + Message string `xml:",innerxml"` +} + func bindResource(stream *Stream, jid JID) (JID, error) { if jid.Resource == "" { return bindResourceServer(stream) diff --git a/src/xmpp/component.go b/src/xmpp/component.go index 5812024..8ac3554 100644 --- a/src/xmpp/component.go +++ b/src/xmpp/component.go @@ -5,7 +5,6 @@ import ( "encoding/xml" "errors" "fmt" - "log" ) // Create a component XMPP connection. @@ -30,10 +29,16 @@ func NewComponentXMPP(addr string, jid JID, secret string) (*XMPP, error) { func startComponent(stream *Stream, jid JID) (string, error) { - s := fmt.Sprintf( - "", - jid) - if err := stream.Send(s); err != nil { + start := xml.StartElement{ + xml.Name{"stream", "stream"}, + []xml.Attr{ + xml.Attr{xml.Name{"", "xmlns"}, "jabber:component:accept"}, + xml.Attr{xml.Name{"xmlns", "stream"}, "http://etherx.jabber.org/streams"}, + xml.Attr{xml.Name{"", "to"}, jid.Full()}, + }, + } + + if err := stream.SendStart(&start); err != nil { return "", err } @@ -63,9 +68,8 @@ func handshake(stream *Stream, streamId, secret string) error { hash.Write([]byte(secret)) // Send handshake. - s := fmt.Sprintf("%x", hash.Sum(nil)) - log.Println(s) - if err := stream.Send(s); err != nil { + handshake := saslHandshake{Value: fmt.Sprintf("%x", hash.Sum(nil))} + if err := stream.Send(&handshake); err != nil { return err } @@ -76,3 +80,8 @@ func handshake(stream *Stream, streamId, secret string) error { return nil } + +type saslHandshake struct { + XMLName xml.Name `xml:"handshake"` + Value string `xml:",innerxml"` +} diff --git a/src/xmpp/stream.go b/src/xmpp/stream.go index ab1024c..99253cb 100644 --- a/src/xmpp/stream.go +++ b/src/xmpp/stream.go @@ -1,6 +1,7 @@ package xmpp import ( + "bytes" "crypto/tls" "encoding/xml" "errors" @@ -19,8 +20,8 @@ type Stream struct { dec *xml.Decoder } -// Create a XML stream connection. See NewClientStream and NewComponentStream -// for something more useful. +// Create a XML stream connection. Typically handles the encoding and decoding +// of XML data for a higher-level API, e.g. XMPP. func NewStream(addr string) (*Stream, error) { log.Println("Connecting to", addr) @@ -43,7 +44,7 @@ func (stream *Stream) UpgradeTLS(config *tls.Config) error { log.Println("Upgrading to TLS") - if err := stream.Send(""); err != nil { + if err := stream.Send(&tlsStart{}); err != nil { return err } @@ -63,31 +64,50 @@ func (stream *Stream) UpgradeTLS(config *tls.Config) error { return nil } -func (stream *Stream) Send(v interface{}) error { +// Send the element's start tag. Typically used to open the stream's document. +func (stream *Stream) SendStart(start *xml.StartElement) error { - var bytes []byte - - switch v2 := v.(type) { - case []byte: - bytes = v2 - case string: - bytes = []byte(v2) - default: - b, err := xml.Marshal(v2) - if err != nil { + buf := new(bytes.Buffer) + if _, err := buf.Write([]byte{'<'}); err != nil { + return err + } + if err := writeXMLName(buf, start.Name); err != nil { + return err + } + for _, attr := range start.Attr { + if _, err := buf.Write([]byte{' '}); err != nil { + return err + } + if err := writeXMLAttr(buf, attr); err != nil { return err } - bytes = b } - - log.Println("send:", string(bytes)) - if _, err := stream.conn.Write(bytes); err != nil { + if _, err := buf.Write([]byte{'>'}); err != nil { return err } + return stream.send(buf.Bytes()) +} + +// Send a stanza. Used to write a complete, top-level element. +func (stream *Stream) Send(v interface{}) error { + bytes, err := xml.Marshal(v) + if err != nil { + return err + } + return stream.send(bytes) +} + +func (stream *Stream) send(b []byte) error { + log.Println("send:", string(b)) + if _, err := stream.conn.Write(b); err != nil { + return err + } return nil } +// Find start of next stanza. If match is not nil the stanza's XML name +// compared and must be equal. func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) { for { t, err := stream.dec.Token() @@ -104,12 +124,20 @@ func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) { panic("Unreachable") } -func (stream *Stream) Decode(i interface{}) error { - return stream.dec.Decode(i) +// Decode the next stanza. Works like xml.Unmarshal but reads from the stream's +// connection. +func (stream *Stream) Decode(v interface{}) error { + return stream.dec.Decode(v) } -func (stream *Stream) DecodeElement(i interface{}, se *xml.StartElement) error { - return stream.dec.DecodeElement(i, se) +// Decode the stanza with the given start element. Works like +// xml.Decoder.DecodeElement. +func (stream *Stream) DecodeElement(v interface{}, start *xml.StartElement) error { + return stream.dec.DecodeElement(v, start) +} + +type tlsStart struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"` } type tlsProceed struct { diff --git a/src/xmpp/xml.go b/src/xmpp/xml.go new file mode 100644 index 0000000..b09b354 --- /dev/null +++ b/src/xmpp/xml.go @@ -0,0 +1,36 @@ +package xmpp + +import ( + "encoding/xml" + "fmt" + "io" +) + +// Write a xml.Name. +func writeXMLName(w io.Writer, name xml.Name) error { + if name.Space == "" { + if _, err := fmt.Fprintf(w, name.Local); err != nil { + return err + } + } else { + if _, err := fmt.Fprintf(w, "%s:%s", name.Space, name.Local); err != nil { + return err + } + } + return nil +} + +// Write a xml.Attr. +func writeXMLAttr(w io.Writer, attr xml.Attr) error { + if err := writeXMLName(w, attr.Name); err != nil { + return err + } + if _, err := w.Write([]byte{'=', '\''}); err != nil { + return err + } + xml.Escape(w, []byte(attr.Value)) + if _, err := w.Write([]byte{'\''}); err != nil { + return err + } + return nil +}